diff --git a/tests/conftest.py b/tests/conftest.py index 8795c9f022..e5eea4d582 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,8 +4,7 @@ import random import time from datetime import datetime, timezone -from enum import Enum -from typing import Callable, TypeVar, Union +from typing import Callable, TypeVar from unittest import mock from unittest.mock import Mock from urllib.parse import urlparse @@ -16,7 +15,6 @@ from redis import Sentinel from redis.auth.idp import IdentityProviderInterface from redis.auth.token import JWToken -from redis.auth.token_manager import RetryPolicy, TokenManagerConfig from redis.backoff import NoBackoff from redis.cache import ( CacheConfig, @@ -29,22 +27,6 @@ from redis.credentials import CredentialProvider from redis.exceptions import RedisClusterException from redis.retry import Retry -from redis_entraid.cred_provider import ( - DEFAULT_DELAY_IN_MS, - DEFAULT_EXPIRATION_REFRESH_RATIO, - DEFAULT_LOWER_REFRESH_BOUND_MILLIS, - DEFAULT_MAX_ATTEMPTS, - DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, - EntraIdCredentialsProvider, -) -from redis_entraid.identity_provider import ( - ManagedIdentityIdType, - ManagedIdentityProviderConfig, - ManagedIdentityType, - ServicePrincipalIdentityProviderConfig, - _create_provider_from_managed_identity, - _create_provider_from_service_principal, -) from tests.ssl_utils import get_tls_certificates REDIS_INFO = {} @@ -60,11 +42,6 @@ _TestDecorator = Callable[[_DecoratedTest], _DecoratedTest] -class AuthType(Enum): - MANAGED_IDENTITY = "managed_identity" - SERVICE_PRINCIPAL = "service_principal" - - # Taken from python3.9 class BooleanOptionalAction(argparse.Action): def __init__( @@ -623,124 +600,18 @@ def mock_identity_provider() -> IdentityProviderInterface: return mock_provider -def identity_provider(request) -> IdentityProviderInterface: - if hasattr(request, "param"): - kwargs = request.param.get("idp_kwargs", {}) - else: - kwargs = {} - - if request.param.get("mock_idp", None) is not None: - return mock_identity_provider() - - auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) - config = get_identity_provider_config(request=request) - - if auth_type == "MANAGED_IDENTITY": - return _create_provider_from_managed_identity(config) - - return _create_provider_from_service_principal(config) - - -def get_identity_provider_config( - request, -) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]: - if hasattr(request, "param"): - kwargs = request.param.get("idp_kwargs", {}) - else: - kwargs = {} - - auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) - - if auth_type == AuthType.MANAGED_IDENTITY: - return _get_managed_identity_provider_config(request) - - return _get_service_principal_provider_config(request) - - -def _get_managed_identity_provider_config(request) -> ManagedIdentityProviderConfig: - resource = os.getenv("AZURE_RESOURCE") - id_value = os.getenv("AZURE_USER_ASSIGNED_MANAGED_ID", None) - - if hasattr(request, "param"): - kwargs = request.param.get("idp_kwargs", {}) - else: - kwargs = {} - - identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED) - id_type = kwargs.pop("id_type", ManagedIdentityIdType.OBJECT_ID) - - return ManagedIdentityProviderConfig( - identity_type=identity_type, - resource=resource, - id_type=id_type, - id_value=id_value, - kwargs=kwargs, - ) - - -def _get_service_principal_provider_config( - request, -) -> ServicePrincipalIdentityProviderConfig: - client_id = os.getenv("AZURE_CLIENT_ID") - client_credential = os.getenv("AZURE_CLIENT_SECRET") - tenant_id = os.getenv("AZURE_TENANT_ID") - scopes = os.getenv("AZURE_REDIS_SCOPES", None) - - if hasattr(request, "param"): - kwargs = request.param.get("idp_kwargs", {}) - token_kwargs = request.param.get("token_kwargs", {}) - timeout = request.param.get("timeout", None) - else: - kwargs = {} - token_kwargs = {} - timeout = None - - if isinstance(scopes, str): - scopes = scopes.split(",") - - return ServicePrincipalIdentityProviderConfig( - client_id=client_id, - client_credential=client_credential, - scopes=scopes, - timeout=timeout, - token_kwargs=token_kwargs, - tenant_id=tenant_id, - app_kwargs=kwargs, - ) - - def get_credential_provider(request) -> CredentialProvider: cred_provider_class = request.param.get("cred_provider_class") cred_provider_kwargs = request.param.get("cred_provider_kwargs", {}) - if cred_provider_class != EntraIdCredentialsProvider: + # Since we can't import EntraIdCredentialsProvider in this module, + # we'll just check the class name. + if cred_provider_class.__name__ != "EntraIdCredentialsProvider": return cred_provider_class(**cred_provider_kwargs) - idp = identity_provider(request) - expiration_refresh_ratio = cred_provider_kwargs.get( - "expiration_refresh_ratio", DEFAULT_EXPIRATION_REFRESH_RATIO - ) - lower_refresh_bound_millis = cred_provider_kwargs.get( - "lower_refresh_bound_millis", DEFAULT_LOWER_REFRESH_BOUND_MILLIS - ) - max_attempts = cred_provider_kwargs.get("max_attempts", DEFAULT_MAX_ATTEMPTS) - delay_in_ms = cred_provider_kwargs.get("delay_in_ms", DEFAULT_DELAY_IN_MS) - - token_mgr_config = TokenManagerConfig( - expiration_refresh_ratio=expiration_refresh_ratio, - lower_refresh_bound_millis=lower_refresh_bound_millis, - token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, # noqa - retry_policy=RetryPolicy( - max_attempts=max_attempts, - delay_in_ms=delay_in_ms, - ), - ) + from tests.entraid_utils import get_entra_id_credentials_provider - return EntraIdCredentialsProvider( - identity_provider=idp, - token_manager_config=token_mgr_config, - initial_delay_in_ms=delay_in_ms, - ) + return get_entra_id_credentials_provider(request, cred_provider_kwargs) @pytest.fixture() diff --git a/tests/entraid_utils.py b/tests/entraid_utils.py new file mode 100644 index 0000000000..daefbd3956 --- /dev/null +++ b/tests/entraid_utils.py @@ -0,0 +1,140 @@ +import os +from enum import Enum +from typing import Union + +from redis.auth.idp import IdentityProviderInterface +from redis.auth.token_manager import RetryPolicy, TokenManagerConfig +from redis_entraid.cred_provider import ( + DEFAULT_DELAY_IN_MS, + DEFAULT_EXPIRATION_REFRESH_RATIO, + DEFAULT_LOWER_REFRESH_BOUND_MILLIS, + DEFAULT_MAX_ATTEMPTS, + DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, + EntraIdCredentialsProvider, +) +from redis_entraid.identity_provider import ( + ManagedIdentityIdType, + ManagedIdentityProviderConfig, + ManagedIdentityType, + ServicePrincipalIdentityProviderConfig, + _create_provider_from_managed_identity, + _create_provider_from_service_principal, +) +from tests.conftest import mock_identity_provider + + +class AuthType(Enum): + MANAGED_IDENTITY = "managed_identity" + SERVICE_PRINCIPAL = "service_principal" + + +def identity_provider(request) -> IdentityProviderInterface: + if hasattr(request, "param"): + kwargs = request.param.get("idp_kwargs", {}) + else: + kwargs = {} + + if request.param.get("mock_idp", None) is not None: + return mock_identity_provider() + + auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) + config = get_identity_provider_config(request=request) + + if auth_type == "MANAGED_IDENTITY": + return _create_provider_from_managed_identity(config) + + return _create_provider_from_service_principal(config) + + +def get_identity_provider_config( + request, +) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]: + if hasattr(request, "param"): + kwargs = request.param.get("idp_kwargs", {}) + else: + kwargs = {} + + auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) + + if auth_type == AuthType.MANAGED_IDENTITY: + return _get_managed_identity_provider_config(request) + + return _get_service_principal_provider_config(request) + + +def _get_managed_identity_provider_config(request) -> ManagedIdentityProviderConfig: + resource = os.getenv("AZURE_RESOURCE") + id_value = os.getenv("AZURE_USER_ASSIGNED_MANAGED_ID", None) + + if hasattr(request, "param"): + kwargs = request.param.get("idp_kwargs", {}) + else: + kwargs = {} + + identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED) + id_type = kwargs.pop("id_type", ManagedIdentityIdType.OBJECT_ID) + + return ManagedIdentityProviderConfig( + identity_type=identity_type, + resource=resource, + id_type=id_type, + id_value=id_value, + kwargs=kwargs, + ) + + +def _get_service_principal_provider_config( + request, +) -> ServicePrincipalIdentityProviderConfig: + client_id = os.getenv("AZURE_CLIENT_ID") + client_credential = os.getenv("AZURE_CLIENT_SECRET") + tenant_id = os.getenv("AZURE_TENANT_ID") + scopes = os.getenv("AZURE_REDIS_SCOPES", None) + + if hasattr(request, "param"): + kwargs = request.param.get("idp_kwargs", {}) + token_kwargs = request.param.get("token_kwargs", {}) + timeout = request.param.get("timeout", None) + else: + kwargs = {} + token_kwargs = {} + timeout = None + + if isinstance(scopes, str): + scopes = scopes.split(",") + + return ServicePrincipalIdentityProviderConfig( + client_id=client_id, + client_credential=client_credential, + scopes=scopes, + timeout=timeout, + token_kwargs=token_kwargs, + tenant_id=tenant_id, + app_kwargs=kwargs, + ) + + +def get_entra_id_credentials_provider(request, cred_provider_kwargs): + idp = identity_provider(request) + expiration_refresh_ratio = cred_provider_kwargs.get( + "expiration_refresh_ratio", DEFAULT_EXPIRATION_REFRESH_RATIO + ) + lower_refresh_bound_millis = cred_provider_kwargs.get( + "lower_refresh_bound_millis", DEFAULT_LOWER_REFRESH_BOUND_MILLIS + ) + max_attempts = cred_provider_kwargs.get("max_attempts", DEFAULT_MAX_ATTEMPTS) + delay_in_ms = cred_provider_kwargs.get("delay_in_ms", DEFAULT_DELAY_IN_MS) + token_mgr_config = TokenManagerConfig( + expiration_refresh_ratio=expiration_refresh_ratio, + lower_refresh_bound_millis=lower_refresh_bound_millis, + token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, # noqa + retry_policy=RetryPolicy( + max_attempts=max_attempts, + delay_in_ms=delay_in_ms, + ), + ) + return EntraIdCredentialsProvider( + identity_provider=idp, + token_manager_config=token_mgr_config, + initial_delay_in_ms=delay_in_ms, + ) diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index d9cccf1b92..226b00aa45 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -1,41 +1,19 @@ -import os import random from contextlib import asynccontextmanager as _asynccontextmanager -from datetime import datetime, timezone from enum import Enum from typing import Union import pytest import pytest_asyncio import redis.asyncio as redis -from mock.mock import Mock from packaging.version import Version from redis.asyncio import Sentinel from redis.asyncio.client import Monitor from redis.asyncio.connection import Connection, parse_url from redis.asyncio.retry import Retry -from redis.auth.idp import IdentityProviderInterface -from redis.auth.token import JWToken -from redis.auth.token_manager import RetryPolicy, TokenManagerConfig from redis.backoff import NoBackoff from redis.credentials import CredentialProvider -from redis_entraid.cred_provider import ( - DEFAULT_DELAY_IN_MS, - DEFAULT_EXPIRATION_REFRESH_RATIO, - DEFAULT_LOWER_REFRESH_BOUND_MILLIS, - DEFAULT_MAX_ATTEMPTS, - DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, - EntraIdCredentialsProvider, -) -from redis_entraid.identity_provider import ( - ManagedIdentityIdType, - ManagedIdentityProviderConfig, - ManagedIdentityType, - ServicePrincipalIdentityProviderConfig, - _create_provider_from_managed_identity, - _create_provider_from_service_principal, -) -from tests.conftest import REDIS_INFO +from tests.conftest import REDIS_INFO, get_credential_provider from .compat import mock @@ -247,136 +225,6 @@ async def mock_cluster_resp_slaves(create_redis, **kwargs): yield mocked -def mock_identity_provider() -> IdentityProviderInterface: - jwt = pytest.importorskip("jwt") - mock_provider = Mock(spec=IdentityProviderInterface) - token = {"exp": datetime.now(timezone.utc).timestamp() + 3600, "oid": "username"} - encoded = jwt.encode(token, "secret", algorithm="HS256") - jwt_token = JWToken(encoded) - mock_provider.request_token.return_value = jwt_token - return mock_provider - - -def identity_provider(request) -> IdentityProviderInterface: - if hasattr(request, "param"): - kwargs = request.param.get("idp_kwargs", {}) - else: - kwargs = {} - - if request.param.get("mock_idp", None) is not None: - return mock_identity_provider() - - auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) - config = get_identity_provider_config(request=request) - - if auth_type == "MANAGED_IDENTITY": - return _create_provider_from_managed_identity(config) - - return _create_provider_from_service_principal(config) - - -def get_identity_provider_config( - request, -) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]: - if hasattr(request, "param"): - kwargs = request.param.get("idp_kwargs", {}) - else: - kwargs = {} - - auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) - - if auth_type == AuthType.MANAGED_IDENTITY: - return _get_managed_identity_provider_config(request) - - return _get_service_principal_provider_config(request) - - -def _get_managed_identity_provider_config(request) -> ManagedIdentityProviderConfig: - resource = os.getenv("AZURE_RESOURCE") - id_value = os.getenv("AZURE_USER_ASSIGNED_MANAGED_ID", None) - - if hasattr(request, "param"): - kwargs = request.param.get("idp_kwargs", {}) - else: - kwargs = {} - - identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED) - id_type = kwargs.pop("id_type", ManagedIdentityIdType.OBJECT_ID) - - return ManagedIdentityProviderConfig( - identity_type=identity_type, - resource=resource, - id_type=id_type, - id_value=id_value, - kwargs=kwargs, - ) - - -def _get_service_principal_provider_config( - request, -) -> ServicePrincipalIdentityProviderConfig: - client_id = os.getenv("AZURE_CLIENT_ID") - client_credential = os.getenv("AZURE_CLIENT_SECRET") - tenant_id = os.getenv("AZURE_TENANT_ID") - scopes = os.getenv("AZURE_REDIS_SCOPES", None) - - if hasattr(request, "param"): - kwargs = request.param.get("idp_kwargs", {}) - token_kwargs = request.param.get("token_kwargs", {}) - timeout = request.param.get("timeout", None) - else: - kwargs = {} - token_kwargs = {} - timeout = None - - if isinstance(scopes, str): - scopes = scopes.split(",") - - return ServicePrincipalIdentityProviderConfig( - client_id=client_id, - client_credential=client_credential, - scopes=scopes, - timeout=timeout, - token_kwargs=token_kwargs, - tenant_id=tenant_id, - app_kwargs=kwargs, - ) - - -def get_credential_provider(request) -> CredentialProvider: - cred_provider_class = request.param.get("cred_provider_class") - cred_provider_kwargs = request.param.get("cred_provider_kwargs", {}) - - if cred_provider_class != EntraIdCredentialsProvider: - return cred_provider_class(**cred_provider_kwargs) - - idp = identity_provider(request) - expiration_refresh_ratio = cred_provider_kwargs.get( - "expiration_refresh_ratio", DEFAULT_EXPIRATION_REFRESH_RATIO - ) - lower_refresh_bound_millis = cred_provider_kwargs.get( - "lower_refresh_bound_millis", DEFAULT_LOWER_REFRESH_BOUND_MILLIS - ) - max_attempts = cred_provider_kwargs.get("max_attempts", DEFAULT_MAX_ATTEMPTS) - delay_in_ms = cred_provider_kwargs.get("delay_in_ms", DEFAULT_DELAY_IN_MS) - - token_mgr_config = TokenManagerConfig( - expiration_refresh_ratio=expiration_refresh_ratio, - lower_refresh_bound_millis=lower_refresh_bound_millis, - token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, # noqa - retry_policy=RetryPolicy( - max_attempts=max_attempts, - delay_in_ms=delay_in_ms, - ), - ) - - return EntraIdCredentialsProvider( - identity_provider=idp, - token_manager_config=token_mgr_config, - initial_delay_in_ms=delay_in_ms, - ) - - @pytest_asyncio.fixture() async def credential_provider(request) -> CredentialProvider: return get_credential_provider(request) diff --git a/tests/test_asyncio/test_credentials.py b/tests/test_asyncio/test_credentials.py index 1eb988ce71..ce8d76ea45 100644 --- a/tests/test_asyncio/test_credentials.py +++ b/tests/test_asyncio/test_credentials.py @@ -17,10 +17,14 @@ from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider from redis.exceptions import ConnectionError from redis.utils import str_if_bytes -from redis_entraid.cred_provider import EntraIdCredentialsProvider from tests.conftest import get_endpoint, skip_if_redis_enterprise from tests.test_asyncio.conftest import get_credential_provider +try: + from redis_entraid.cred_provider import EntraIdCredentialsProvider +except ImportError: + EntraIdCredentialsProvider = None + @pytest.fixture() def endpoint(request): @@ -321,6 +325,7 @@ async def test_user_pass_provider_only_password( @pytest.mark.asyncio @pytest.mark.onlynoncluster +@pytest.mark.skipif(not EntraIdCredentialsProvider, reason="requires redis-entraid") class TestStreamingCredentialProvider: @pytest.mark.parametrize( "credential_provider", @@ -599,6 +604,7 @@ async def test_fails_on_token_renewal(self, credential_provider): @pytest.mark.asyncio @pytest.mark.onlynoncluster @pytest.mark.cp_integration +@pytest.mark.skipif(not EntraIdCredentialsProvider, reason="requires redis-entraid") class TestEntraIdCredentialsProvider: @pytest.mark.parametrize( "r_credential", @@ -674,6 +680,7 @@ async def test_async_auth_pubsub_with_credential_provider( @pytest.mark.asyncio @pytest.mark.onlycluster @pytest.mark.cp_integration +@pytest.mark.skipif(not EntraIdCredentialsProvider, reason="requires redis-entraid") class TestClusterEntraIdCredentialsProvider: @pytest.mark.parametrize( "r_credential", diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 95ec5577cc..1f98c5208d 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -16,7 +16,6 @@ from redis.exceptions import ConnectionError, RedisError from redis.retry import Retry from redis.utils import str_if_bytes -from redis_entraid.cred_provider import EntraIdCredentialsProvider from tests.conftest import ( _get_client, get_credential_provider, @@ -24,6 +23,11 @@ skip_if_redis_enterprise, ) +try: + from redis_entraid.cred_provider import EntraIdCredentialsProvider +except ImportError: + EntraIdCredentialsProvider = None + @pytest.fixture() def endpoint(request): @@ -295,6 +299,7 @@ def test_user_pass_provider_only_password(self, r, request): @pytest.mark.onlynoncluster +@pytest.mark.skipif(not EntraIdCredentialsProvider, reason="requires redis-entraid") class TestStreamingCredentialProvider: @pytest.mark.parametrize( "credential_provider", @@ -567,6 +572,7 @@ def test_fails_on_token_renewal(self, credential_provider): @pytest.mark.onlynoncluster @pytest.mark.cp_integration +@pytest.mark.skipif(not EntraIdCredentialsProvider, reason="requires redis-entraid") class TestEntraIdCredentialsProvider: @pytest.mark.parametrize( "r_entra", @@ -637,6 +643,7 @@ def test_auth_pubsub_with_credential_provider(self, r_entra: redis.Redis): @pytest.mark.onlycluster @pytest.mark.cp_integration +@pytest.mark.skipif(not EntraIdCredentialsProvider, reason="requires redis-entraid") class TestClusterEntraIdCredentialsProvider: @pytest.mark.parametrize( "r_entra",