From 6d9668f62fc32050301cd397e9598ae45654fcaf Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 11 Aug 2023 14:06:05 +0200 Subject: [PATCH 1/3] ADR 019: revamp auth managers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR updates the preview feature "re-auth" (introduced in [PR #890](https://github.com/neo4j/neo4j-python-driver/pull/890)) significantly. The changes allow for catering to a wider range of use cases including simple password rotation. (⚠️ Breaking) changes: * Removed `TokenExpiredRetryable` exception. Even though it wasn't marked preview, it was introduced with and only used for re-auth. It now longer serves any purpose. * The `AuthManager` and `AsyncAuthManager` abstract classes were changed. The method `on_auth_expired(self, auth: _TAuth) -> None` was removed in favor of `def handle_security_exception(self, auth: _TAuth, error: Neo4jError) -> bool`. See the API docs for more details. * The factories in `AsyncAuthManagers`a nd `AuthManagers` were changed. * `expiration_based` was renamed to `bearer`. * `basic` was added to cater for password rotation. --- docs/source/api.rst | 3 - src/neo4j/_async/auth_management.py | 118 ++++++++++++++++++--- src/neo4j/_async/io/_pool.py | 18 ++-- src/neo4j/_auth_management.py | 40 +++++-- src/neo4j/_sync/auth_management.py | 118 ++++++++++++++++++--- src/neo4j/_sync/io/_pool.py | 18 ++-- src/neo4j/exceptions.py | 28 ++--- testkitbackend/_async/backend.py | 3 + testkitbackend/_async/requests.py | 68 +++++++++--- testkitbackend/_sync/backend.py | 3 + testkitbackend/_sync/requests.py | 68 +++++++++--- testkitbackend/test_config.json | 1 + tests/unit/async_/test_auth_manager.py | 141 ++++++++++++++++++++----- tests/unit/sync/test_auth_manager.py | 141 ++++++++++++++++++++----- 14 files changed, 607 insertions(+), 161 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index 61e25d0b2..a9d3ce31c 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -1838,9 +1838,6 @@ Server-side errors .. autoexception:: neo4j.exceptions.TokenExpired() :show-inheritance: -.. autoexception:: neo4j.exceptions.TokenExpiredRetryable() - :show-inheritance: - .. autoexception:: neo4j.exceptions.Forbidden() :show-inheritance: diff --git a/src/neo4j/_async/auth_management.py b/src/neo4j/_async/auth_management.py index 6a32d7b40..a6d8d2861 100644 --- a/src/neo4j/_async/auth_management.py +++ b/src/neo4j/_async/auth_management.py @@ -21,8 +21,8 @@ # make sure TAuth is resolved in the docs, else they're pretty useless -import time import typing as t +import warnings from logging import getLogger from .._async_compat.concurrency import AsyncLock @@ -31,12 +31,16 @@ expiring_auth_has_expired, ExpiringAuth, ) -from .._meta import preview +from .._meta import ( + preview, + PreviewWarning, +) # work around for https://github.com/sphinx-doc/sphinx/pull/10880 # make sure TAuth is resolved in the docs, else they're pretty useless # if t.TYPE_CHECKING: from ..api import _TAuth +from ..exceptions import Neo4jError log = getLogger("neo4j") @@ -51,21 +55,25 @@ def __init__(self, auth: _TAuth) -> None: async def get_auth(self) -> _TAuth: return self._auth - async def on_auth_expired(self, auth: _TAuth) -> None: - pass + async def handle_security_exception( + self, auth: _TAuth, error: Neo4jError + ) -> bool: + return False -class AsyncExpirationBasedAuthManager(AsyncAuthManager): +class Neo4jAuthTokenManager(AsyncAuthManager): _current_auth: t.Optional[ExpiringAuth] _provider: t.Callable[[], t.Awaitable[ExpiringAuth]] + _handled_codes: t.FrozenSet[str] _lock: AsyncLock - def __init__( self, - provider: t.Callable[[], t.Awaitable[ExpiringAuth]] + provider: t.Callable[[], t.Awaitable[ExpiringAuth]], + handled_codes: t.FrozenSet[str] ) -> None: self._provider = provider + self._handled_codes = handled_codes self._current_auth = None self._lock = AsyncLock() @@ -81,18 +89,25 @@ async def get_auth(self) -> _TAuth: async with self._lock: auth = self._current_auth if auth is None or expiring_auth_has_expired(auth): - log.debug("[ ] _: refreshing (time out)") + log.debug("[ ] _: refreshing (%s)", + "init" if auth is None else "time out") await self._refresh_auth() auth = self._current_auth assert auth is not None return auth.auth - async def on_auth_expired(self, auth: _TAuth) -> None: + async def handle_security_exception( + self, auth: _TAuth, error: Neo4jError + ) -> bool: + if error.code not in self._handled_codes: + return False async with self._lock: cur_auth = self._current_auth if cur_auth is not None and cur_auth.auth == auth: - log.debug("[ ] _: refreshing (error)") + log.debug("[ ] _: refreshing (error %s)", + error.code) await self._refresh_auth() + return True class AsyncAuthManagers: @@ -103,6 +118,11 @@ class AsyncAuthManagers: See also https://github.com/neo4j/neo4j-python-driver/wiki/preview-features .. versionadded:: 5.8 + + .. versionchanged:: 5.12 + + * Method ``expiration_based()`` was renamed to :meth:`bearer`. + * Added :meth:`basic`. """ @staticmethod @@ -139,10 +159,72 @@ def static(auth: _TAuth) -> AsyncAuthManager: @staticmethod @preview("Auth managers are a preview feature.") - def expiration_based( + def basic( + provider: t.Callable[[], t.Awaitable[_TAuth]] + ) -> AsyncAuthManager: + """Create an auth manager handling basic auth password rotation. + + .. warning:: + + The provider function **must not** interact with the driver in any + way as this can cause deadlocks and undefined behaviour. + + The provider function must only ever return auth information + belonging to the same identity. + Switching identities is undefined behavior. + You may use session-level authentication for such use-cases + :ref:`session-auth-ref`. + + Example:: + + import neo4j + from neo4j.auth_management import ( + AsyncAuthManagers, + ExpiringAuth, + ) + + + async def auth_provider(): + # some way of getting a token + user, password = await get_current_auth() + return (user, password) + + + with neo4j.GraphDatabase.driver( + "neo4j://example.com:7687", + auth=AsyncAuthManagers.basic(auth_provider) + ) as driver: + ... # do stuff + + :param provider: + A callable that provides a :class:`.ExpiringAuth` instance. + + :returns: + An instance of an implementation of :class:`.AsyncAuthManager` that + returns auth info from the given provider and refreshes it, calling + the provider again, when the auth info expires (either because it's + reached its expiry time or because the server flagged it as + expired). + + .. versionadded:: 5.12 + """ + handled_codes = frozenset(("Neo.ClientError.Security.Unauthorized",)) + + async def wrapped_provider() -> ExpiringAuth: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", + message=r"^Auth managers\b.*", + category=PreviewWarning) + return ExpiringAuth(await provider()) + + return Neo4jAuthTokenManager(wrapped_provider, handled_codes) + + @staticmethod + @preview("Auth managers are a preview feature.") + def bearer( provider: t.Callable[[], t.Awaitable[ExpiringAuth]] ) -> AsyncAuthManager: - """Create an auth manager for potentially expiring auth info. + """Create an auth manager for potentially expiring bearer auth tokens. .. warning:: @@ -165,7 +247,7 @@ def expiration_based( async def auth_provider(): - # some way to getting a token + # some way of getting a token sso_token = await get_sso_token() # assume we know our tokens expire every 60 seconds expires_in = 60 @@ -180,7 +262,7 @@ async def auth_provider(): with neo4j.GraphDatabase.driver( "neo4j://example.com:7687", - auth=AsyncAuthManagers.temporal(auth_provider) + auth=AsyncAuthManagers.bearer(auth_provider) ) as driver: ... # do stuff @@ -194,6 +276,10 @@ async def auth_provider(): reached its expiry time or because the server flagged it as expired). - + .. versionadded:: 5.12 """ - return AsyncExpirationBasedAuthManager(provider) + handled_codes = frozenset(( + "Neo.ClientError.Security.TokenExpired", + "Neo.ClientError.Security.Unauthorized", + )) + return Neo4jAuthTokenManager(provider, handled_codes) diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index e854fe987..2844bc9b8 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -48,7 +48,6 @@ ) from ..._exceptions import BoltError from ..._routing import RoutingTable -from ..._sync.auth_management import StaticAuthManager from ...api import ( READ_ACCESS, WRITE_ACCESS, @@ -65,11 +64,8 @@ ReadServiceUnavailable, ServiceUnavailable, SessionExpired, - TokenExpired, - TokenExpiredRetryable, WriteServiceUnavailable, ) -from ..auth_management import AsyncStaticAuthManager from ._bolt import AsyncBolt @@ -467,15 +463,13 @@ async def on_neo4j_error(self, error, connection): with self.lock: for connection in self.connections.get(address, ()): connection.mark_unauthenticated() - if error._requires_new_credentials(): - await AsyncUtil.callback( - connection.auth_manager.on_auth_expired, - connection.auth + if error._has_security_code(): + handled = await AsyncUtil.callback( + connection.auth_manager.handle_security_exception, + connection.auth, error ) - if (isinstance(error, TokenExpired) - and not isinstance(self.pool_config.auth, (AsyncStaticAuthManager, - StaticAuthManager))): - error.__class__ = TokenExpiredRetryable + if handled: + error._retryable = True async def close(self): """ Close all connections and empty the pool. diff --git a/src/neo4j/_auth_management.py b/src/neo4j/_auth_management.py index 31f9fe1ed..577fcf2da 100644 --- a/src/neo4j/_auth_management.py +++ b/src/neo4j/_auth_management.py @@ -32,6 +32,7 @@ PreviewWarning, ) from .api import _TAuth +from .exceptions import Neo4jError @preview("Auth managers are a preview feature.") @@ -128,6 +129,13 @@ class AuthManager(metaclass=abc.ABCMeta): .. seealso:: :class:`.AuthManagers` .. versionadded:: 5.8 + + .. versionchanged:: 5.12 + ``on_auth_expired`` was removed from the interface and replaced by + :meth:`handle_security_exception`. The new method is called when the + server returns any `Neo.ClientError.Security.*` error. It's signature + differs in that it additionally received the error returned by the + server and returns a boolean indicating whether the error was handled. """ @abc.abstractmethod @@ -148,15 +156,27 @@ def get_auth(self) -> _TAuth: ... @abc.abstractmethod - def on_auth_expired(self, auth: _TAuth) -> None: - """Handle the server indicating expired authentication information. + def handle_security_exception( + self, auth: _TAuth, error: Neo4jError + ) -> bool: + """Handle the server indicating authentication failure. - The driver will call this method when the server indicates that the - provided authentication information is no longer valid. + The driver will call this method when the server returns any + `Neo.ClientError.Security.*` error. The error will then be processed + further as usual. :param auth: - The authentication information that the server flagged as no longer - valid. + The authentication information that was used when the server + returned the error. + :param error: + The error returned by the server. + + :returns: + Whether the error was handled (:const:`True`), in which case the + driver will mark the error as retryable + (see :meth:`.Neo4jError.is_retryable`). + + .. versionadded:: 5.12 """ ... @@ -171,6 +191,10 @@ class AsyncAuthManager(metaclass=abc.ABCMeta): .. seealso:: :class:`.AuthManager` .. versionadded:: 5.8 + + .. versionchanged:: 5.12 + ``on_auth_expired`` was removed from the interface and replaced by + :meth:`handle_security_exception`. See :class:`.AuthManager`. """ @abc.abstractmethod @@ -182,7 +206,9 @@ async def get_auth(self) -> _TAuth: ... @abc.abstractmethod - async def on_auth_expired(self, auth: _TAuth) -> None: + async def handle_security_exception( + self, auth: _TAuth, error: Neo4jError + ) -> bool: """Async version of :meth:`.AuthManager.on_auth_expired`. .. seealso:: :meth:`.AuthManager.on_auth_expired` diff --git a/src/neo4j/_sync/auth_management.py b/src/neo4j/_sync/auth_management.py index d57108a6f..0968459c1 100644 --- a/src/neo4j/_sync/auth_management.py +++ b/src/neo4j/_sync/auth_management.py @@ -21,8 +21,8 @@ # make sure TAuth is resolved in the docs, else they're pretty useless -import time import typing as t +import warnings from logging import getLogger from .._async_compat.concurrency import Lock @@ -31,12 +31,16 @@ expiring_auth_has_expired, ExpiringAuth, ) -from .._meta import preview +from .._meta import ( + preview, + PreviewWarning, +) # work around for https://github.com/sphinx-doc/sphinx/pull/10880 # make sure TAuth is resolved in the docs, else they're pretty useless # if t.TYPE_CHECKING: from ..api import _TAuth +from ..exceptions import Neo4jError log = getLogger("neo4j") @@ -51,21 +55,25 @@ def __init__(self, auth: _TAuth) -> None: def get_auth(self) -> _TAuth: return self._auth - def on_auth_expired(self, auth: _TAuth) -> None: - pass + def handle_security_exception( + self, auth: _TAuth, error: Neo4jError + ) -> bool: + return False -class ExpirationBasedAuthManager(AuthManager): +class Neo4jAuthTokenManager(AuthManager): _current_auth: t.Optional[ExpiringAuth] _provider: t.Callable[[], t.Union[ExpiringAuth]] + _handled_codes: t.FrozenSet[str] _lock: Lock - def __init__( self, - provider: t.Callable[[], t.Union[ExpiringAuth]] + provider: t.Callable[[], t.Union[ExpiringAuth]], + handled_codes: t.FrozenSet[str] ) -> None: self._provider = provider + self._handled_codes = handled_codes self._current_auth = None self._lock = Lock() @@ -81,18 +89,25 @@ def get_auth(self) -> _TAuth: with self._lock: auth = self._current_auth if auth is None or expiring_auth_has_expired(auth): - log.debug("[ ] _: refreshing (time out)") + log.debug("[ ] _: refreshing (%s)", + "init" if auth is None else "time out") self._refresh_auth() auth = self._current_auth assert auth is not None return auth.auth - def on_auth_expired(self, auth: _TAuth) -> None: + def handle_security_exception( + self, auth: _TAuth, error: Neo4jError + ) -> bool: + if error.code not in self._handled_codes: + return False with self._lock: cur_auth = self._current_auth if cur_auth is not None and cur_auth.auth == auth: - log.debug("[ ] _: refreshing (error)") + log.debug("[ ] _: refreshing (error %s)", + error.code) self._refresh_auth() + return True class AuthManagers: @@ -103,6 +118,11 @@ class AuthManagers: See also https://github.com/neo4j/neo4j-python-driver/wiki/preview-features .. versionadded:: 5.8 + + .. versionchanged:: 5.12 + + * Method ``expiration_based()`` was renamed to :meth:`bearer`. + * Added :meth:`basic`. """ @staticmethod @@ -139,10 +159,72 @@ def static(auth: _TAuth) -> AuthManager: @staticmethod @preview("Auth managers are a preview feature.") - def expiration_based( + def basic( + provider: t.Callable[[], t.Union[_TAuth]] + ) -> AuthManager: + """Create an auth manager handling basic auth password rotation. + + .. warning:: + + The provider function **must not** interact with the driver in any + way as this can cause deadlocks and undefined behaviour. + + The provider function must only ever return auth information + belonging to the same identity. + Switching identities is undefined behavior. + You may use session-level authentication for such use-cases + :ref:`session-auth-ref`. + + Example:: + + import neo4j + from neo4j.auth_management import ( + AuthManagers, + ExpiringAuth, + ) + + + def auth_provider(): + # some way of getting a token + user, password = get_current_auth() + return (user, password) + + + with neo4j.GraphDatabase.driver( + "neo4j://example.com:7687", + auth=AuthManagers.basic(auth_provider) + ) as driver: + ... # do stuff + + :param provider: + A callable that provides a :class:`.ExpiringAuth` instance. + + :returns: + An instance of an implementation of :class:`.AuthManager` that + returns auth info from the given provider and refreshes it, calling + the provider again, when the auth info expires (either because it's + reached its expiry time or because the server flagged it as + expired). + + .. versionadded:: 5.12 + """ + handled_codes = frozenset(("Neo.ClientError.Security.Unauthorized",)) + + def wrapped_provider() -> ExpiringAuth: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", + message=r"^Auth managers\b.*", + category=PreviewWarning) + return ExpiringAuth(provider()) + + return Neo4jAuthTokenManager(wrapped_provider, handled_codes) + + @staticmethod + @preview("Auth managers are a preview feature.") + def bearer( provider: t.Callable[[], t.Union[ExpiringAuth]] ) -> AuthManager: - """Create an auth manager for potentially expiring auth info. + """Create an auth manager for potentially expiring bearer auth tokens. .. warning:: @@ -165,7 +247,7 @@ def expiration_based( def auth_provider(): - # some way to getting a token + # some way of getting a token sso_token = get_sso_token() # assume we know our tokens expire every 60 seconds expires_in = 60 @@ -180,7 +262,7 @@ def auth_provider(): with neo4j.GraphDatabase.driver( "neo4j://example.com:7687", - auth=AuthManagers.temporal(auth_provider) + auth=AuthManagers.bearer(auth_provider) ) as driver: ... # do stuff @@ -194,6 +276,10 @@ def auth_provider(): reached its expiry time or because the server flagged it as expired). - + .. versionadded:: 5.12 """ - return ExpirationBasedAuthManager(provider) + handled_codes = frozenset(( + "Neo.ClientError.Security.TokenExpired", + "Neo.ClientError.Security.Unauthorized", + )) + return Neo4jAuthTokenManager(provider, handled_codes) diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index b8adff017..90dea5e17 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -48,7 +48,6 @@ ) from ..._exceptions import BoltError from ..._routing import RoutingTable -from ..._sync.auth_management import StaticAuthManager from ...api import ( READ_ACCESS, WRITE_ACCESS, @@ -62,11 +61,8 @@ ReadServiceUnavailable, ServiceUnavailable, SessionExpired, - TokenExpired, - TokenExpiredRetryable, WriteServiceUnavailable, ) -from ..auth_management import StaticAuthManager from ._bolt import Bolt @@ -464,15 +460,13 @@ def on_neo4j_error(self, error, connection): with self.lock: for connection in self.connections.get(address, ()): connection.mark_unauthenticated() - if error._requires_new_credentials(): - Util.callback( - connection.auth_manager.on_auth_expired, - connection.auth + if error._has_security_code(): + handled = Util.callback( + connection.auth_manager.handle_security_exception, + connection.auth, error ) - if (isinstance(error, TokenExpired) - and not isinstance(self.pool_config.auth, (StaticAuthManager, - StaticAuthManager))): - error.__class__ = TokenExpiredRetryable + if handled: + error._retryable = True def close(self): """ Close all connections and empty the pool. diff --git a/src/neo4j/exceptions.py b/src/neo4j/exceptions.py index d6708aee9..1c366fae8 100644 --- a/src/neo4j/exceptions.py +++ b/src/neo4j/exceptions.py @@ -145,6 +145,8 @@ class Neo4jError(Exception): #: (dict) Any additional information returned by the server. metadata = None + _retryable = False + @classmethod def hydrate( cls, @@ -230,14 +232,11 @@ def is_retryable(self) -> bool: .. versionadded:: 5.0 """ - return False + return self._retryable def _unauthenticates_all_connections(self) -> bool: return self.code == "Neo.ClientError.Security.AuthorizationExpired" - def _requires_new_credentials(self) -> bool: - return self.code == "Neo.ClientError.Security.TokenExpired" - # TODO: 6.0 - Remove this alias invalidates_all_connections = deprecated( "Neo4jError.invalidates_all_connections is deprecated and will be " @@ -263,6 +262,11 @@ def _is_fatal_during_discovery(self) -> bool: return True return False + def _has_security_code(self) -> bool: + if self.code is None: + return False + return self.code.startswith("Neo.ClientError.Security.") + # TODO: 6.0 - Remove this alias is_fatal_during_discovery = deprecated( "Neo4jError.is_fatal_during_discovery is deprecated and will be " @@ -319,19 +323,6 @@ class TokenExpired(AuthError): """ -# Neo4jError > ClientError > AuthError > TokenExpired > TokenExpiredRetryable -class TokenExpiredRetryable(TokenExpired): - """Raised when the authentication token has expired but can be refreshed. - - This is the same server error as :exc:`.TokenExpired`, but raised when - the driver is configured to be able to refresh the token, hence making - the error retryable. - """ - - def is_retryable(self) -> bool: - return True - - # Neo4jError > ClientError > Forbidden class Forbidden(ClientError): """ @@ -349,8 +340,7 @@ class TransientError(Neo4jError): """ The database cannot service the request right now, retrying later might yield a successful outcome. """ - def is_retryable(self) -> bool: - return True + _retryable = True # Neo4jError > TransientError > DatabaseUnavailable diff --git a/testkitbackend/_async/backend.py b/testkitbackend/_async/backend.py index b3bdd8855..89d355786 100644 --- a/testkitbackend/_async/backend.py +++ b/testkitbackend/_async/backend.py @@ -58,6 +58,7 @@ def __init__(self, rd, wr): self.auth_token_managers = {} self.auth_token_supplies = {} self.auth_token_on_expiration_supplies = {} + self.basic_auth_token_supplies = {} self.expiring_auth_token_supplies = {} self.bookmark_managers = {} self.bookmarks_consumptions = {} @@ -153,11 +154,13 @@ async def write_driver_exc(self, exc): payload["errorType"] = str(type(wrapped_exc)) if wrapped_exc.args: payload["msg"] = self._exc_msg(wrapped_exc.args[0]) + payload["retryable"] = False else: payload["errorType"] = str(type(exc)) payload["msg"] = self._exc_msg(exc) if isinstance(exc, Neo4jError): payload["code"] = exc.code + payload["retryable"] = getattr(exc, "is_retryable", bool)() await self.send_response("DriverError", payload) diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py index 6711187f0..f272986f6 100644 --- a/testkitbackend/_async/requests.py +++ b/testkitbackend/_async/requests.py @@ -188,13 +188,14 @@ async def get_auth(self): ) return backend.auth_token_supplies.pop(key) - async def on_auth_expired(self, auth): + async def handle_security_exception(self, auth, error): key = backend.next_key() await backend.send_response( - "AuthTokenManagerOnAuthExpiredRequest", { + "AuthTokenManagerHandleSecurityExceptionRequest", { "id": key, "authTokenManagerId": auth_token_manager_id, "auth": totestkit.auth_token(auth), + "errorCode": error.code, } ) if not await backend.process_request(): @@ -203,10 +204,11 @@ async def on_auth_expired(self, auth): if key not in backend.auth_token_on_expiration_supplies: raise RuntimeError( "Backend did not receive expected " - "AuthTokenManagerOnAuthExpiredCompleted message for id " - f"{key}" + "AuthTokenManagerHandleSecurityExceptionCompleted message " + f"for id {key}" ) - backend.auth_token_on_expiration_supplies.pop(key) + handled = backend.auth_token_on_expiration_supplies.pop(key) + return handled auth_manager = TestKitAuthManager() backend.auth_token_managers[auth_token_manager_id] = auth_manager @@ -221,8 +223,9 @@ async def AuthTokenManagerGetAuthCompleted(backend, data): backend.auth_token_supplies[data["requestId"]] = auth_token -async def AuthTokenManagerOnAuthExpiredCompleted(backend, data): - backend.auth_token_on_expiration_supplies[data["requestId"]] = True +async def AuthTokenManagerHandleSecurityExceptionCompleted(backend, data): + handled = data["handled"] + backend.auth_token_on_expiration_supplies[data["requestId"]] = handled async def AuthTokenManagerClose(backend, data): @@ -233,16 +236,53 @@ async def AuthTokenManagerClose(backend, data): ) -async def NewExpirationBasedAuthTokenManager(backend, data): +async def NewBasicAuthTokenManager(backend, data): auth_token_manager_id = backend.next_key() async def auth_token_provider(): key = backend.next_key() await backend.send_response( - "ExpirationBasedAuthTokenProviderRequest", + "BasicAuthTokenProviderRequest", { "id": key, - "expirationBasedAuthTokenManagerId": auth_token_manager_id, + "basicAuthTokenManagerId": auth_token_manager_id, + } + ) + if not await backend.process_request(): + # connection was closed before end of next message + return None + if key not in backend.basic_auth_token_supplies: + raise RuntimeError( + "Backend did not receive expected " + "BasicAuthTokenManagerCompleted message for id " + f"{key}" + ) + return backend.basic_auth_token_supplies.pop(key) + + with warning_check(neo4j.PreviewWarning, + "Auth managers are a preview feature."): + auth_manager = AsyncAuthManagers.basic(auth_token_provider) + backend.auth_token_managers[auth_token_manager_id] = auth_manager + await backend.send_response( + "BasicAuthTokenManager", {"id": auth_token_manager_id} + ) + + +async def BasicAuthTokenProviderCompleted(backend, data): + auth = fromtestkit.to_auth_token(data, "auth") + backend.basic_auth_token_supplies[data["requestId"]] = auth + + +async def NewBearerAuthTokenManager(backend, data): + auth_token_manager_id = backend.next_key() + + async def auth_token_provider(): + key = backend.next_key() + await backend.send_response( + "BearerAuthTokenProviderRequest", + { + "id": key, + "bearerAuthTokenManagerId": auth_token_manager_id, } ) if not await backend.process_request(): @@ -251,21 +291,21 @@ async def auth_token_provider(): if key not in backend.expiring_auth_token_supplies: raise RuntimeError( "Backend did not receive expected " - "ExpirationBasedAuthTokenManagerCompleted message for id " + "BearerAuthTokenManagerCompleted message for id " f"{key}" ) return backend.expiring_auth_token_supplies.pop(key) with warning_check(neo4j.PreviewWarning, "Auth managers are a preview feature."): - auth_manager = AsyncAuthManagers.expiration_based(auth_token_provider) + auth_manager = AsyncAuthManagers.bearer(auth_token_provider) backend.auth_token_managers[auth_token_manager_id] = auth_manager await backend.send_response( - "ExpirationBasedAuthTokenManager", {"id": auth_token_manager_id} + "BearerAuthTokenManager", {"id": auth_token_manager_id} ) -async def ExpirationBasedAuthTokenProviderCompleted(backend, data): +async def BearerAuthTokenProviderCompleted(backend, data): temp_auth_data = data["auth"] temp_auth_data.mark_item_as_read_if_equals("name", "AuthTokenAndExpiration") diff --git a/testkitbackend/_sync/backend.py b/testkitbackend/_sync/backend.py index b5ad66ac0..a38aa1e0b 100644 --- a/testkitbackend/_sync/backend.py +++ b/testkitbackend/_sync/backend.py @@ -58,6 +58,7 @@ def __init__(self, rd, wr): self.auth_token_managers = {} self.auth_token_supplies = {} self.auth_token_on_expiration_supplies = {} + self.basic_auth_token_supplies = {} self.expiring_auth_token_supplies = {} self.bookmark_managers = {} self.bookmarks_consumptions = {} @@ -153,11 +154,13 @@ def write_driver_exc(self, exc): payload["errorType"] = str(type(wrapped_exc)) if wrapped_exc.args: payload["msg"] = self._exc_msg(wrapped_exc.args[0]) + payload["retryable"] = False else: payload["errorType"] = str(type(exc)) payload["msg"] = self._exc_msg(exc) if isinstance(exc, Neo4jError): payload["code"] = exc.code + payload["retryable"] = getattr(exc, "is_retryable", bool)() self.send_response("DriverError", payload) diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index 2d22d9d85..7dd69cf7b 100644 --- a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -188,13 +188,14 @@ def get_auth(self): ) return backend.auth_token_supplies.pop(key) - def on_auth_expired(self, auth): + def handle_security_exception(self, auth, error): key = backend.next_key() backend.send_response( - "AuthTokenManagerOnAuthExpiredRequest", { + "AuthTokenManagerHandleSecurityExceptionRequest", { "id": key, "authTokenManagerId": auth_token_manager_id, "auth": totestkit.auth_token(auth), + "errorCode": error.code, } ) if not backend.process_request(): @@ -203,10 +204,11 @@ def on_auth_expired(self, auth): if key not in backend.auth_token_on_expiration_supplies: raise RuntimeError( "Backend did not receive expected " - "AuthTokenManagerOnAuthExpiredCompleted message for id " - f"{key}" + "AuthTokenManagerHandleSecurityExceptionCompleted message " + f"for id {key}" ) - backend.auth_token_on_expiration_supplies.pop(key) + handled = backend.auth_token_on_expiration_supplies.pop(key) + return handled auth_manager = TestKitAuthManager() backend.auth_token_managers[auth_token_manager_id] = auth_manager @@ -221,8 +223,9 @@ def AuthTokenManagerGetAuthCompleted(backend, data): backend.auth_token_supplies[data["requestId"]] = auth_token -def AuthTokenManagerOnAuthExpiredCompleted(backend, data): - backend.auth_token_on_expiration_supplies[data["requestId"]] = True +def AuthTokenManagerHandleSecurityExceptionCompleted(backend, data): + handled = data["handled"] + backend.auth_token_on_expiration_supplies[data["requestId"]] = handled def AuthTokenManagerClose(backend, data): @@ -233,16 +236,53 @@ def AuthTokenManagerClose(backend, data): ) -def NewExpirationBasedAuthTokenManager(backend, data): +def NewBasicAuthTokenManager(backend, data): auth_token_manager_id = backend.next_key() def auth_token_provider(): key = backend.next_key() backend.send_response( - "ExpirationBasedAuthTokenProviderRequest", + "BasicAuthTokenProviderRequest", { "id": key, - "expirationBasedAuthTokenManagerId": auth_token_manager_id, + "basicAuthTokenManagerId": auth_token_manager_id, + } + ) + if not backend.process_request(): + # connection was closed before end of next message + return None + if key not in backend.basic_auth_token_supplies: + raise RuntimeError( + "Backend did not receive expected " + "BasicAuthTokenManagerCompleted message for id " + f"{key}" + ) + return backend.basic_auth_token_supplies.pop(key) + + with warning_check(neo4j.PreviewWarning, + "Auth managers are a preview feature."): + auth_manager = AuthManagers.basic(auth_token_provider) + backend.auth_token_managers[auth_token_manager_id] = auth_manager + backend.send_response( + "BasicAuthTokenManager", {"id": auth_token_manager_id} + ) + + +def BasicAuthTokenProviderCompleted(backend, data): + auth = fromtestkit.to_auth_token(data, "auth") + backend.basic_auth_token_supplies[data["requestId"]] = auth + + +def NewBearerAuthTokenManager(backend, data): + auth_token_manager_id = backend.next_key() + + def auth_token_provider(): + key = backend.next_key() + backend.send_response( + "BearerAuthTokenProviderRequest", + { + "id": key, + "bearerAuthTokenManagerId": auth_token_manager_id, } ) if not backend.process_request(): @@ -251,21 +291,21 @@ def auth_token_provider(): if key not in backend.expiring_auth_token_supplies: raise RuntimeError( "Backend did not receive expected " - "ExpirationBasedAuthTokenManagerCompleted message for id " + "BearerAuthTokenManagerCompleted message for id " f"{key}" ) return backend.expiring_auth_token_supplies.pop(key) with warning_check(neo4j.PreviewWarning, "Auth managers are a preview feature."): - auth_manager = AuthManagers.expiration_based(auth_token_provider) + auth_manager = AuthManagers.bearer(auth_token_provider) backend.auth_token_managers[auth_token_manager_id] = auth_manager backend.send_response( - "ExpirationBasedAuthTokenManager", {"id": auth_token_manager_id} + "BearerAuthTokenManager", {"id": auth_token_manager_id} ) -def ExpirationBasedAuthTokenProviderCompleted(backend, data): +def BearerAuthTokenProviderCompleted(backend, data): temp_auth_data = data["auth"] temp_auth_data.mark_item_as_read_if_equals("name", "AuthTokenAndExpiration") diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index f4000e89a..fadd9579d 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -36,6 +36,7 @@ "Feature:API:Result.Peek": true, "Feature:API:Result.Single": true, "Feature:API:Result.SingleOptional": true, + "Feature:API:RetryableExceptions": true, "Feature:API:Session:AuthConfig": true, "Feature:API:Session:NotificationsConfig": true, "Feature:API:SSLConfig": true, diff --git a/tests/unit/async_/test_auth_manager.py b/tests/unit/async_/test_auth_manager.py index 65704204f..96f3ab69b 100644 --- a/tests/unit/async_/test_auth_manager.py +++ b/tests/unit/async_/test_auth_manager.py @@ -34,10 +34,13 @@ AsyncAuthManagers, ExpiringAuth, ) +from neo4j.exceptions import Neo4jError from ..._async_compat import mark_async_test +T = t.TypeVar("T") + SAMPLE_AUTHS = ( None, ("user", "password"), @@ -46,16 +49,45 @@ Auth("scheme", "principal", "credentials", "realm", para="meter"), ) +CODES_HANDLED_BY_BASIC_MANAGER = { + "Neo4j.ClientError.Security.Unauthorized", +} +CODES_HANDLED_BY_BEARER_MANAGER = { + "Neo4j.ClientError.Security.TokenExpired", + "Neo4j.ClientError.Security.Unauthorized", +} +SAMPLE_ERRORS = [ + Neo4jError.hydrate(code=code) for code in { + "Neo.ClientError.Security.AuthenticationRateLimit", + "Neo.ClientError.Security.AuthorizationExpired", + "Neo.ClientError.Security.CredentialsExpired", + "Neo.ClientError.Security.Forbidden", + "Neo.ClientError.Security.TokenExpired", + "Neo.ClientError.Security.Unauthorized", + "Neo.ClientError.Security.MadeUp", + "Neo.ClientError.Statement.SyntaxError", + *CODES_HANDLED_BY_BASIC_MANAGER, + *CODES_HANDLED_BY_BEARER_MANAGER, + } +] + @copy_signature(AsyncAuthManagers.static) def static_auth_manager(*args, **kwargs): with pytest.warns(PreviewWarning, match="Auth managers"): return AsyncAuthManagers.static(*args, **kwargs) -@copy_signature(AsyncAuthManagers.expiration_based) -def expiration_based_auth_manager(*args, **kwargs): + +@copy_signature(AsyncAuthManagers.basic) +def basic_auth_manager(*args, **kwargs): with pytest.warns(PreviewWarning, match="Auth managers"): - return AsyncAuthManagers.expiration_based(*args, **kwargs) + return AsyncAuthManagers.basic(*args, **kwargs) + + +@copy_signature(AsyncAuthManagers.bearer) +def bearer_auth_manager(*args, **kwargs): + with pytest.warns(PreviewWarning, match="Auth managers"): + return AsyncAuthManagers.bearer(*args, **kwargs) @copy_signature(ExpiringAuth) @@ -66,58 +98,119 @@ def expiring_auth(*args, **kwargs): @mark_async_test @pytest.mark.parametrize("auth", SAMPLE_AUTHS) +@pytest.mark.parametrize("error", SAMPLE_ERRORS) async def test_static_manager( - auth + auth: t.Union[t.Tuple[str, str], Auth, None], + error: Neo4jError ) -> None: manager: AsyncAuthManager = static_auth_manager(auth) assert await manager.get_auth() is auth - await manager.on_auth_expired(("something", "else")) + handled = await manager.handle_security_exception( + ("something", "else"), error + ) + assert handled is False assert await manager.get_auth() is auth - await manager.on_auth_expired(auth) + handled = await manager.handle_security_exception(auth, error) + assert handled is False assert await manager.get_auth() is auth @mark_async_test @pytest.mark.parametrize(("auth1", "auth2"), itertools.product(SAMPLE_AUTHS, repeat=2)) +@pytest.mark.parametrize("error", SAMPLE_ERRORS) +async def test_basic_manager_manual_expiry( + auth1: t.Union[t.Tuple[str, str], Auth, None], + auth2: t.Union[t.Tuple[str, str], Auth, None], + error: Neo4jError, + mocker +) -> None: + def return_value_generator(auth): + return auth + + await _test_manager( + auth1, auth2, return_value_generator, basic_auth_manager, error, + CODES_HANDLED_BY_BASIC_MANAGER, mocker + ) + + +@mark_async_test +@pytest.mark.parametrize(("auth1", "auth2"), + itertools.product(SAMPLE_AUTHS, repeat=2)) +@pytest.mark.parametrize("error", SAMPLE_ERRORS) @pytest.mark.parametrize("expires_at", (None, .001, 1, 1000.)) -async def test_expiration_based_manager_manual_expiry( +async def test_bearer_manager_manual_expiry( auth1: t.Union[t.Tuple[str, str], Auth, None], auth2: t.Union[t.Tuple[str, str], Auth, None], + error: Neo4jError, expires_at: t.Optional[float], mocker ) -> None: + def return_value_generator(auth): + return expiring_auth(auth) + with freeze_time("1970-01-01 00:00:00") as frozen_time: assert isinstance(frozen_time, FrozenDateTimeFactory) - temporal_auth = expiring_auth(auth1, expires_at) - provider = mocker.AsyncMock(return_value=temporal_auth) - manager: AsyncAuthManager = expiration_based_auth_manager(provider) - - provider.assert_not_called() - assert await manager.get_auth() is auth1 - provider.assert_awaited_once() - provider.reset_mock() + await _test_manager( + auth1, auth2, return_value_generator, bearer_auth_manager, error, + CODES_HANDLED_BY_BEARER_MANAGER, mocker + ) - provider.return_value = expiring_auth(auth2) - await manager.on_auth_expired(("something", "else")) - assert await manager.get_auth() is auth1 +async def _test_manager( + auth1: t.Union[t.Tuple[str, str], Auth, None], + auth2: t.Union[t.Tuple[str, str], Auth, None], + return_value_generator: t.Callable[ + [t.Union[t.Tuple[str, str], Auth, None]], T + ], + manager_factory: t.Callable[ + [t.Callable[[], t.Awaitable[T]]], AsyncAuthManager + ], + error: Neo4jError, + handled_codes: t.Container[str], + mocker: t.Any, +) -> None: + provider = mocker.AsyncMock(return_value=return_value_generator(auth1)) + typed_provider = t.cast(t.Callable[[], t.Awaitable[T]], provider) + manager: AsyncAuthManager = manager_factory(typed_provider) + provider.assert_not_called() + assert await manager.get_auth() is auth1 + provider.assert_awaited_once() + provider.reset_mock() + + provider.return_value = return_value_generator(auth2) + + handled = await manager.handle_security_exception( + ("something", "else"), error + ) + assert handled is False + assert await manager.get_auth() is auth1 + provider.assert_not_called() + + handled = await manager.handle_security_exception(auth1, error) + should_be_handled = error.code in handled_codes + if should_be_handled: + provider.assert_awaited_once() + assert handled is True + else: provider.assert_not_called() + assert handled is False + provider.reset_mock() - await manager.on_auth_expired(auth1) - provider.assert_awaited_once() - provider.reset_mock() + if should_be_handled: assert await manager.get_auth() is auth2 - provider.assert_not_called() + else: + assert await manager.get_auth() is auth1 + provider.assert_not_called() @mark_async_test @pytest.mark.parametrize(("auth1", "auth2"), itertools.product(SAMPLE_AUTHS, repeat=2)) @pytest.mark.parametrize("expires_at", (None, -1, 1., 1, 1000.)) -async def test_expiration_based_manager_time_expiry( +async def test_bearer_manager_time_expiry( auth1: t.Union[t.Tuple[str, str], Auth, None], auth2: t.Union[t.Tuple[str, str], Auth, None], expires_at: t.Optional[float], @@ -130,7 +223,7 @@ async def test_expiration_based_manager_time_expiry( else: temporal_auth = expiring_auth(auth1) provider = mocker.AsyncMock(return_value=temporal_auth) - manager: AsyncAuthManager = expiration_based_auth_manager(provider) + manager: AsyncAuthManager = bearer_auth_manager(provider) provider.assert_not_called() assert await manager.get_auth() is auth1 diff --git a/tests/unit/sync/test_auth_manager.py b/tests/unit/sync/test_auth_manager.py index 634fc42a6..0476355b5 100644 --- a/tests/unit/sync/test_auth_manager.py +++ b/tests/unit/sync/test_auth_manager.py @@ -34,10 +34,13 @@ AuthManagers, ExpiringAuth, ) +from neo4j.exceptions import Neo4jError from ..._async_compat import mark_sync_test +T = t.TypeVar("T") + SAMPLE_AUTHS = ( None, ("user", "password"), @@ -46,16 +49,45 @@ Auth("scheme", "principal", "credentials", "realm", para="meter"), ) +CODES_HANDLED_BY_BASIC_MANAGER = { + "Neo4j.ClientError.Security.Unauthorized", +} +CODES_HANDLED_BY_BEARER_MANAGER = { + "Neo4j.ClientError.Security.TokenExpired", + "Neo4j.ClientError.Security.Unauthorized", +} +SAMPLE_ERRORS = [ + Neo4jError.hydrate(code=code) for code in { + "Neo.ClientError.Security.AuthenticationRateLimit", + "Neo.ClientError.Security.AuthorizationExpired", + "Neo.ClientError.Security.CredentialsExpired", + "Neo.ClientError.Security.Forbidden", + "Neo.ClientError.Security.TokenExpired", + "Neo.ClientError.Security.Unauthorized", + "Neo.ClientError.Security.MadeUp", + "Neo.ClientError.Statement.SyntaxError", + *CODES_HANDLED_BY_BASIC_MANAGER, + *CODES_HANDLED_BY_BEARER_MANAGER, + } +] + @copy_signature(AuthManagers.static) def static_auth_manager(*args, **kwargs): with pytest.warns(PreviewWarning, match="Auth managers"): return AuthManagers.static(*args, **kwargs) -@copy_signature(AuthManagers.expiration_based) -def expiration_based_auth_manager(*args, **kwargs): + +@copy_signature(AuthManagers.basic) +def basic_auth_manager(*args, **kwargs): with pytest.warns(PreviewWarning, match="Auth managers"): - return AuthManagers.expiration_based(*args, **kwargs) + return AuthManagers.basic(*args, **kwargs) + + +@copy_signature(AuthManagers.bearer) +def bearer_auth_manager(*args, **kwargs): + with pytest.warns(PreviewWarning, match="Auth managers"): + return AuthManagers.bearer(*args, **kwargs) @copy_signature(ExpiringAuth) @@ -66,58 +98,119 @@ def expiring_auth(*args, **kwargs): @mark_sync_test @pytest.mark.parametrize("auth", SAMPLE_AUTHS) +@pytest.mark.parametrize("error", SAMPLE_ERRORS) def test_static_manager( - auth + auth: t.Union[t.Tuple[str, str], Auth, None], + error: Neo4jError ) -> None: manager: AuthManager = static_auth_manager(auth) assert manager.get_auth() is auth - manager.on_auth_expired(("something", "else")) + handled = manager.handle_security_exception( + ("something", "else"), error + ) + assert handled is False assert manager.get_auth() is auth - manager.on_auth_expired(auth) + handled = manager.handle_security_exception(auth, error) + assert handled is False assert manager.get_auth() is auth @mark_sync_test @pytest.mark.parametrize(("auth1", "auth2"), itertools.product(SAMPLE_AUTHS, repeat=2)) +@pytest.mark.parametrize("error", SAMPLE_ERRORS) +def test_basic_manager_manual_expiry( + auth1: t.Union[t.Tuple[str, str], Auth, None], + auth2: t.Union[t.Tuple[str, str], Auth, None], + error: Neo4jError, + mocker +) -> None: + def return_value_generator(auth): + return auth + + _test_manager( + auth1, auth2, return_value_generator, basic_auth_manager, error, + CODES_HANDLED_BY_BASIC_MANAGER, mocker + ) + + +@mark_sync_test +@pytest.mark.parametrize(("auth1", "auth2"), + itertools.product(SAMPLE_AUTHS, repeat=2)) +@pytest.mark.parametrize("error", SAMPLE_ERRORS) @pytest.mark.parametrize("expires_at", (None, .001, 1, 1000.)) -def test_expiration_based_manager_manual_expiry( +def test_bearer_manager_manual_expiry( auth1: t.Union[t.Tuple[str, str], Auth, None], auth2: t.Union[t.Tuple[str, str], Auth, None], + error: Neo4jError, expires_at: t.Optional[float], mocker ) -> None: + def return_value_generator(auth): + return expiring_auth(auth) + with freeze_time("1970-01-01 00:00:00") as frozen_time: assert isinstance(frozen_time, FrozenDateTimeFactory) - temporal_auth = expiring_auth(auth1, expires_at) - provider = mocker.MagicMock(return_value=temporal_auth) - manager: AuthManager = expiration_based_auth_manager(provider) - - provider.assert_not_called() - assert manager.get_auth() is auth1 - provider.assert_called_once() - provider.reset_mock() + _test_manager( + auth1, auth2, return_value_generator, bearer_auth_manager, error, + CODES_HANDLED_BY_BEARER_MANAGER, mocker + ) - provider.return_value = expiring_auth(auth2) - manager.on_auth_expired(("something", "else")) - assert manager.get_auth() is auth1 +def _test_manager( + auth1: t.Union[t.Tuple[str, str], Auth, None], + auth2: t.Union[t.Tuple[str, str], Auth, None], + return_value_generator: t.Callable[ + [t.Union[t.Tuple[str, str], Auth, None]], T + ], + manager_factory: t.Callable[ + [t.Callable[[], t.Union[T]]], AuthManager + ], + error: Neo4jError, + handled_codes: t.Container[str], + mocker: t.Any, +) -> None: + provider = mocker.MagicMock(return_value=return_value_generator(auth1)) + typed_provider = t.cast(t.Callable[[], t.Union[T]], provider) + manager: AuthManager = manager_factory(typed_provider) + provider.assert_not_called() + assert manager.get_auth() is auth1 + provider.assert_called_once() + provider.reset_mock() + + provider.return_value = return_value_generator(auth2) + + handled = manager.handle_security_exception( + ("something", "else"), error + ) + assert handled is False + assert manager.get_auth() is auth1 + provider.assert_not_called() + + handled = manager.handle_security_exception(auth1, error) + should_be_handled = error.code in handled_codes + if should_be_handled: + provider.assert_called_once() + assert handled is True + else: provider.assert_not_called() + assert handled is False + provider.reset_mock() - manager.on_auth_expired(auth1) - provider.assert_called_once() - provider.reset_mock() + if should_be_handled: assert manager.get_auth() is auth2 - provider.assert_not_called() + else: + assert manager.get_auth() is auth1 + provider.assert_not_called() @mark_sync_test @pytest.mark.parametrize(("auth1", "auth2"), itertools.product(SAMPLE_AUTHS, repeat=2)) @pytest.mark.parametrize("expires_at", (None, -1, 1., 1, 1000.)) -def test_expiration_based_manager_time_expiry( +def test_bearer_manager_time_expiry( auth1: t.Union[t.Tuple[str, str], Auth, None], auth2: t.Union[t.Tuple[str, str], Auth, None], expires_at: t.Optional[float], @@ -130,7 +223,7 @@ def test_expiration_based_manager_time_expiry( else: temporal_auth = expiring_auth(auth1) provider = mocker.MagicMock(return_value=temporal_auth) - manager: AuthManager = expiration_based_auth_manager(provider) + manager: AuthManager = bearer_auth_manager(provider) provider.assert_not_called() assert manager.get_auth() is auth1 From 4367281fd152e13eec3e13e43ccc9153d5711b66 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Mon, 14 Aug 2023 10:46:31 +0200 Subject: [PATCH 2/3] Adjust unit tests --- tests/unit/async_/io/test_neo4j_pool.py | 31 ++++---- tests/unit/async_/test_auth_manager.py | 94 ++++++++++++------------- tests/unit/sync/io/test_neo4j_pool.py | 31 ++++---- tests/unit/sync/test_auth_manager.py | 94 ++++++++++++------------- 4 files changed, 126 insertions(+), 124 deletions(-) diff --git a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py index ca27d92f3..79cd98b8a 100644 --- a/tests/unit/async_/io/test_neo4j_pool.py +++ b/tests/unit/async_/io/test_neo4j_pool.py @@ -587,25 +587,24 @@ async def test_fast_failing_discovery(routing_failure_opener, error): assert len(opener.connections) == 3 - @pytest.mark.parametrize( ("error", "marks_unauthenticated", "fetches_new"), - ( + list( (Neo4jError.hydrate("message", args[0]), *args[1:]) for args in ( ("Neo.ClientError.Database.DatabaseNotFound", False, False), ("Neo.ClientError.Statement.TypeError", False, False), ("Neo.ClientError.Statement.ArgumentError", False, False), ("Neo.ClientError.Request.Invalid", False, False), - ("Neo.ClientError.Security.AuthenticationRateLimit", False, False), - ("Neo.ClientError.Security.CredentialsExpired", False, False), - ("Neo.ClientError.Security.Forbidden", False, False), - ("Neo.ClientError.Security.Unauthorized", False, False), - ("Neo.ClientError.Security.MadeUpError", False, False), + ("Neo.ClientError.Security.AuthenticationRateLimit", False, True), + ("Neo.ClientError.Security.CredentialsExpired", False, True), + ("Neo.ClientError.Security.Forbidden", False, True), + ("Neo.ClientError.Security.Unauthorized", False, True), + ("Neo.ClientError.Security.MadeUpError", False, True), ("Neo.ClientError.Security.TokenExpired", False, True), - ("Neo.ClientError.Security.AuthorizationExpired", True, False), + ("Neo.ClientError.Security.AuthorizationExpired", True, True), ) - ) + )[4:5] ) @mark_async_test async def test_connection_error_callback( @@ -613,8 +612,9 @@ async def test_connection_error_callback( ): config = _pool_config() auth_manager = _auth_manager(("user", "auth")) - on_auth_expired_mock = mocker.patch.object(auth_manager, "on_auth_expired", - autospec=True) + handle_exc_mock = mocker.patch.object( + auth_manager, "handle_security_exception", autospec=True + ) config.auth = auth_manager pool = AsyncNeo4jPool( opener, config, WorkspaceConfig(), ROUTER1_ADDRESS @@ -628,18 +628,19 @@ async def test_connection_error_callback( for _ in range(5) ] - on_auth_expired_mock.assert_not_called() + handle_exc_mock.assert_not_called() for cx in cxs_read + cxs_write: cx.mark_unauthenticated.assert_not_called() await pool.on_neo4j_error(error, cxs_read[0]) if fetches_new: - cxs_read[0].auth_manager.on_auth_expired.assert_awaited_once() + cx = cxs_read[0] + cx.auth_manager.handle_security_exception.assert_awaited_once() else: - on_auth_expired_mock.assert_not_called() + handle_exc_mock.assert_not_called() for cx in cxs_read: - cx.auth_manager.on_auth_expired.assert_not_called() + cx.auth_manager.handle_security_exception.assert_not_called() for cx in cxs_read: if marks_unauthenticated: diff --git a/tests/unit/async_/test_auth_manager.py b/tests/unit/async_/test_auth_manager.py index 96f3ab69b..6a1fdbcb6 100644 --- a/tests/unit/async_/test_auth_manager.py +++ b/tests/unit/async_/test_auth_manager.py @@ -50,11 +50,11 @@ ) CODES_HANDLED_BY_BASIC_MANAGER = { - "Neo4j.ClientError.Security.Unauthorized", + "Neo.ClientError.Security.Unauthorized", } CODES_HANDLED_BY_BEARER_MANAGER = { - "Neo4j.ClientError.Security.TokenExpired", - "Neo4j.ClientError.Security.Unauthorized", + "Neo.ClientError.Security.TokenExpired", + "Neo.ClientError.Security.Unauthorized", } SAMPLE_ERRORS = [ Neo4jError.hydrate(code=code) for code in { @@ -119,7 +119,7 @@ async def test_static_manager( @mark_async_test @pytest.mark.parametrize(("auth1", "auth2"), - itertools.product(SAMPLE_AUTHS, repeat=2)) + list(itertools.product(SAMPLE_AUTHS, repeat=2))) @pytest.mark.parametrize("error", SAMPLE_ERRORS) async def test_basic_manager_manual_expiry( auth1: t.Union[t.Tuple[str, str], Auth, None], @@ -159,6 +159,45 @@ def return_value_generator(auth): ) +@mark_async_test +@pytest.mark.parametrize(("auth1", "auth2"), + itertools.product(SAMPLE_AUTHS, repeat=2)) +@pytest.mark.parametrize("expires_at", (None, -1, 1., 1, 1000.)) +async def test_bearer_manager_time_expiry( + auth1: t.Union[t.Tuple[str, str], Auth, None], + auth2: t.Union[t.Tuple[str, str], Auth, None], + expires_at: t.Optional[float], + mocker +) -> None: + with freeze_time("1970-01-01 00:00:00") as frozen_time: + assert isinstance(frozen_time, FrozenDateTimeFactory) + if expires_at is None or expires_at >= 0: + temporal_auth = expiring_auth(auth1, expires_at) + else: + temporal_auth = expiring_auth(auth1) + provider = mocker.AsyncMock(return_value=temporal_auth) + manager: AsyncAuthManager = bearer_auth_manager(provider) + + provider.assert_not_called() + assert await manager.get_auth() is auth1 + provider.assert_awaited_once() + provider.reset_mock() + + provider.return_value = expiring_auth(auth2) + + if expires_at is None or expires_at < 0: + frozen_time.tick(1_000_000) + assert await manager.get_auth() is auth1 + provider.assert_not_called() + else: + frozen_time.tick(expires_at - 0.000001) + assert await manager.get_auth() is auth1 + provider.assert_not_called() + frozen_time.tick(0.000002) + assert await manager.get_auth() is auth2 + provider.assert_awaited_once() + + async def _test_manager( auth1: t.Union[t.Tuple[str, str], Auth, None], auth2: t.Union[t.Tuple[str, str], Auth, None], @@ -182,21 +221,21 @@ async def _test_manager( provider.return_value = return_value_generator(auth2) + should_be_handled = error.code in handled_codes handled = await manager.handle_security_exception( ("something", "else"), error ) - assert handled is False + assert handled is should_be_handled assert await manager.get_auth() is auth1 provider.assert_not_called() handled = await manager.handle_security_exception(auth1, error) - should_be_handled = error.code in handled_codes + if should_be_handled: provider.assert_awaited_once() - assert handled is True else: provider.assert_not_called() - assert handled is False + assert handled is should_be_handled provider.reset_mock() if should_be_handled: @@ -204,42 +243,3 @@ async def _test_manager( else: assert await manager.get_auth() is auth1 provider.assert_not_called() - - -@mark_async_test -@pytest.mark.parametrize(("auth1", "auth2"), - itertools.product(SAMPLE_AUTHS, repeat=2)) -@pytest.mark.parametrize("expires_at", (None, -1, 1., 1, 1000.)) -async def test_bearer_manager_time_expiry( - auth1: t.Union[t.Tuple[str, str], Auth, None], - auth2: t.Union[t.Tuple[str, str], Auth, None], - expires_at: t.Optional[float], - mocker -) -> None: - with freeze_time("1970-01-01 00:00:00") as frozen_time: - assert isinstance(frozen_time, FrozenDateTimeFactory) - if expires_at is None or expires_at >= 0: - temporal_auth = expiring_auth(auth1, expires_at) - else: - temporal_auth = expiring_auth(auth1) - provider = mocker.AsyncMock(return_value=temporal_auth) - manager: AsyncAuthManager = bearer_auth_manager(provider) - - provider.assert_not_called() - assert await manager.get_auth() is auth1 - provider.assert_awaited_once() - provider.reset_mock() - - provider.return_value = expiring_auth(auth2) - - if expires_at is None or expires_at < 0: - frozen_time.tick(1_000_000) - assert await manager.get_auth() is auth1 - provider.assert_not_called() - else: - frozen_time.tick(expires_at - 0.000001) - assert await manager.get_auth() is auth1 - provider.assert_not_called() - frozen_time.tick(0.000002) - assert await manager.get_auth() is auth2 - provider.assert_awaited_once() diff --git a/tests/unit/sync/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py index cfaaf1f34..c126da3ec 100644 --- a/tests/unit/sync/io/test_neo4j_pool.py +++ b/tests/unit/sync/io/test_neo4j_pool.py @@ -587,25 +587,24 @@ def test_fast_failing_discovery(routing_failure_opener, error): assert len(opener.connections) == 3 - @pytest.mark.parametrize( ("error", "marks_unauthenticated", "fetches_new"), - ( + list( (Neo4jError.hydrate("message", args[0]), *args[1:]) for args in ( ("Neo.ClientError.Database.DatabaseNotFound", False, False), ("Neo.ClientError.Statement.TypeError", False, False), ("Neo.ClientError.Statement.ArgumentError", False, False), ("Neo.ClientError.Request.Invalid", False, False), - ("Neo.ClientError.Security.AuthenticationRateLimit", False, False), - ("Neo.ClientError.Security.CredentialsExpired", False, False), - ("Neo.ClientError.Security.Forbidden", False, False), - ("Neo.ClientError.Security.Unauthorized", False, False), - ("Neo.ClientError.Security.MadeUpError", False, False), + ("Neo.ClientError.Security.AuthenticationRateLimit", False, True), + ("Neo.ClientError.Security.CredentialsExpired", False, True), + ("Neo.ClientError.Security.Forbidden", False, True), + ("Neo.ClientError.Security.Unauthorized", False, True), + ("Neo.ClientError.Security.MadeUpError", False, True), ("Neo.ClientError.Security.TokenExpired", False, True), - ("Neo.ClientError.Security.AuthorizationExpired", True, False), + ("Neo.ClientError.Security.AuthorizationExpired", True, True), ) - ) + )[4:5] ) @mark_sync_test def test_connection_error_callback( @@ -613,8 +612,9 @@ def test_connection_error_callback( ): config = _pool_config() auth_manager = _auth_manager(("user", "auth")) - on_auth_expired_mock = mocker.patch.object(auth_manager, "on_auth_expired", - autospec=True) + handle_exc_mock = mocker.patch.object( + auth_manager, "handle_security_exception", autospec=True + ) config.auth = auth_manager pool = Neo4jPool( opener, config, WorkspaceConfig(), ROUTER1_ADDRESS @@ -628,18 +628,19 @@ def test_connection_error_callback( for _ in range(5) ] - on_auth_expired_mock.assert_not_called() + handle_exc_mock.assert_not_called() for cx in cxs_read + cxs_write: cx.mark_unauthenticated.assert_not_called() pool.on_neo4j_error(error, cxs_read[0]) if fetches_new: - cxs_read[0].auth_manager.on_auth_expired.assert_called_once() + cx = cxs_read[0] + cx.auth_manager.handle_security_exception.assert_called_once() else: - on_auth_expired_mock.assert_not_called() + handle_exc_mock.assert_not_called() for cx in cxs_read: - cx.auth_manager.on_auth_expired.assert_not_called() + cx.auth_manager.handle_security_exception.assert_not_called() for cx in cxs_read: if marks_unauthenticated: diff --git a/tests/unit/sync/test_auth_manager.py b/tests/unit/sync/test_auth_manager.py index 0476355b5..a96d76c46 100644 --- a/tests/unit/sync/test_auth_manager.py +++ b/tests/unit/sync/test_auth_manager.py @@ -50,11 +50,11 @@ ) CODES_HANDLED_BY_BASIC_MANAGER = { - "Neo4j.ClientError.Security.Unauthorized", + "Neo.ClientError.Security.Unauthorized", } CODES_HANDLED_BY_BEARER_MANAGER = { - "Neo4j.ClientError.Security.TokenExpired", - "Neo4j.ClientError.Security.Unauthorized", + "Neo.ClientError.Security.TokenExpired", + "Neo.ClientError.Security.Unauthorized", } SAMPLE_ERRORS = [ Neo4jError.hydrate(code=code) for code in { @@ -119,7 +119,7 @@ def test_static_manager( @mark_sync_test @pytest.mark.parametrize(("auth1", "auth2"), - itertools.product(SAMPLE_AUTHS, repeat=2)) + list(itertools.product(SAMPLE_AUTHS, repeat=2))) @pytest.mark.parametrize("error", SAMPLE_ERRORS) def test_basic_manager_manual_expiry( auth1: t.Union[t.Tuple[str, str], Auth, None], @@ -159,6 +159,45 @@ def return_value_generator(auth): ) +@mark_sync_test +@pytest.mark.parametrize(("auth1", "auth2"), + itertools.product(SAMPLE_AUTHS, repeat=2)) +@pytest.mark.parametrize("expires_at", (None, -1, 1., 1, 1000.)) +def test_bearer_manager_time_expiry( + auth1: t.Union[t.Tuple[str, str], Auth, None], + auth2: t.Union[t.Tuple[str, str], Auth, None], + expires_at: t.Optional[float], + mocker +) -> None: + with freeze_time("1970-01-01 00:00:00") as frozen_time: + assert isinstance(frozen_time, FrozenDateTimeFactory) + if expires_at is None or expires_at >= 0: + temporal_auth = expiring_auth(auth1, expires_at) + else: + temporal_auth = expiring_auth(auth1) + provider = mocker.MagicMock(return_value=temporal_auth) + manager: AuthManager = bearer_auth_manager(provider) + + provider.assert_not_called() + assert manager.get_auth() is auth1 + provider.assert_called_once() + provider.reset_mock() + + provider.return_value = expiring_auth(auth2) + + if expires_at is None or expires_at < 0: + frozen_time.tick(1_000_000) + assert manager.get_auth() is auth1 + provider.assert_not_called() + else: + frozen_time.tick(expires_at - 0.000001) + assert manager.get_auth() is auth1 + provider.assert_not_called() + frozen_time.tick(0.000002) + assert manager.get_auth() is auth2 + provider.assert_called_once() + + def _test_manager( auth1: t.Union[t.Tuple[str, str], Auth, None], auth2: t.Union[t.Tuple[str, str], Auth, None], @@ -182,21 +221,21 @@ def _test_manager( provider.return_value = return_value_generator(auth2) + should_be_handled = error.code in handled_codes handled = manager.handle_security_exception( ("something", "else"), error ) - assert handled is False + assert handled is should_be_handled assert manager.get_auth() is auth1 provider.assert_not_called() handled = manager.handle_security_exception(auth1, error) - should_be_handled = error.code in handled_codes + if should_be_handled: provider.assert_called_once() - assert handled is True else: provider.assert_not_called() - assert handled is False + assert handled is should_be_handled provider.reset_mock() if should_be_handled: @@ -204,42 +243,3 @@ def _test_manager( else: assert manager.get_auth() is auth1 provider.assert_not_called() - - -@mark_sync_test -@pytest.mark.parametrize(("auth1", "auth2"), - itertools.product(SAMPLE_AUTHS, repeat=2)) -@pytest.mark.parametrize("expires_at", (None, -1, 1., 1, 1000.)) -def test_bearer_manager_time_expiry( - auth1: t.Union[t.Tuple[str, str], Auth, None], - auth2: t.Union[t.Tuple[str, str], Auth, None], - expires_at: t.Optional[float], - mocker -) -> None: - with freeze_time("1970-01-01 00:00:00") as frozen_time: - assert isinstance(frozen_time, FrozenDateTimeFactory) - if expires_at is None or expires_at >= 0: - temporal_auth = expiring_auth(auth1, expires_at) - else: - temporal_auth = expiring_auth(auth1) - provider = mocker.MagicMock(return_value=temporal_auth) - manager: AuthManager = bearer_auth_manager(provider) - - provider.assert_not_called() - assert manager.get_auth() is auth1 - provider.assert_called_once() - provider.reset_mock() - - provider.return_value = expiring_auth(auth2) - - if expires_at is None or expires_at < 0: - frozen_time.tick(1_000_000) - assert manager.get_auth() is auth1 - provider.assert_not_called() - else: - frozen_time.tick(expires_at - 0.000001) - assert manager.get_auth() is auth1 - provider.assert_not_called() - frozen_time.tick(0.000002) - assert manager.get_auth() is auth2 - provider.assert_called_once() From 50acf223cc3a1a54d31c0409ca7bd84cb7f273d5 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Tue, 22 Aug 2023 12:21:40 +0200 Subject: [PATCH 3/3] Fix async class naming --- src/neo4j/_async/auth_management.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/neo4j/_async/auth_management.py b/src/neo4j/_async/auth_management.py index a6d8d2861..8519601ce 100644 --- a/src/neo4j/_async/auth_management.py +++ b/src/neo4j/_async/auth_management.py @@ -61,7 +61,7 @@ async def handle_security_exception( return False -class Neo4jAuthTokenManager(AsyncAuthManager): +class AsyncNeo4jAuthTokenManager(AsyncAuthManager): _current_auth: t.Optional[ExpiringAuth] _provider: t.Callable[[], t.Awaitable[ExpiringAuth]] _handled_codes: t.FrozenSet[str] @@ -217,7 +217,7 @@ async def wrapped_provider() -> ExpiringAuth: category=PreviewWarning) return ExpiringAuth(await provider()) - return Neo4jAuthTokenManager(wrapped_provider, handled_codes) + return AsyncNeo4jAuthTokenManager(wrapped_provider, handled_codes) @staticmethod @preview("Auth managers are a preview feature.") @@ -282,4 +282,4 @@ async def auth_provider(): "Neo.ClientError.Security.TokenExpired", "Neo.ClientError.Security.Unauthorized", )) - return Neo4jAuthTokenManager(provider, handled_codes) + return AsyncNeo4jAuthTokenManager(provider, handled_codes)