Skip to content

(tests): Added testing for auth via DefaultAzureCredential #3544

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dev_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ ujson>=4.2.0
uvloop
vulture>=2.3.0
numpy>=1.24.0
redis-entraid==0.3.0b1
redis-entraid==0.4.0b2
39 changes: 36 additions & 3 deletions tests/entraid_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,16 @@
ServicePrincipalIdentityProviderConfig,
_create_provider_from_managed_identity,
_create_provider_from_service_principal,
DefaultAzureCredentialIdentityProviderConfig,
_create_provider_from_default_azure_credential,
)
from tests.conftest import mock_identity_provider


class AuthType(Enum):
MANAGED_IDENTITY = "managed_identity"
SERVICE_PRINCIPAL = "service_principal"
DEFAULT_AZURE_CREDENTIAL = "default_azure_credential"


def identity_provider(request) -> IdentityProviderInterface:
Expand All @@ -37,18 +40,25 @@ def identity_provider(request) -> IdentityProviderInterface:
if request.param.get("mock_idp", None) is not None:
return mock_identity_provider()

auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL)
auth_type = kwargs.get("auth_type", AuthType.SERVICE_PRINCIPAL)
config = get_identity_provider_config(request=request)

if auth_type == "MANAGED_IDENTITY":
if auth_type == AuthType.MANAGED_IDENTITY:
return _create_provider_from_managed_identity(config)

if auth_type == AuthType.DEFAULT_AZURE_CREDENTIAL:
return _create_provider_from_default_azure_credential(config)

return _create_provider_from_service_principal(config)


def get_identity_provider_config(
request,
) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]:
) -> Union[
ManagedIdentityProviderConfig,
ServicePrincipalIdentityProviderConfig,
DefaultAzureCredentialIdentityProviderConfig,
]:
if hasattr(request, "param"):
kwargs = request.param.get("idp_kwargs", {})
else:
Expand All @@ -59,6 +69,9 @@ def get_identity_provider_config(
if auth_type == AuthType.MANAGED_IDENTITY:
return _get_managed_identity_provider_config(request)

if auth_type == AuthType.DEFAULT_AZURE_CREDENTIAL:
return _get_default_azure_credential_provider_config(request)

return _get_service_principal_provider_config(request)


Expand Down Expand Up @@ -114,6 +127,26 @@ def _get_service_principal_provider_config(
)


def _get_default_azure_credential_provider_config(
request,
) -> DefaultAzureCredentialIdentityProviderConfig:
scopes = os.getenv("AZURE_REDIS_SCOPES", ())

if hasattr(request, "param"):
kwargs = request.param.get("idp_kwargs", {})
token_kwargs = request.param.get("token_kwargs", {})
else:
kwargs = {}
token_kwargs = {}

if isinstance(scopes, str):
scopes = scopes.split(",")

return DefaultAzureCredentialIdentityProviderConfig(
scopes=scopes, app_kwargs=kwargs, token_kwargs=token_kwargs
)


def get_entra_id_credentials_provider(request, cred_provider_kwargs):
idp = identity_provider(request)
expiration_refresh_ratio = cred_provider_kwargs.get(
Expand Down
6 changes: 0 additions & 6 deletions tests/test_asyncio/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import random
from contextlib import asynccontextmanager as _asynccontextmanager
from enum import Enum
from typing import Union

import pytest
Expand All @@ -18,11 +17,6 @@
from .compat import mock


class AuthType(Enum):
MANAGED_IDENTITY = "managed_identity"
SERVICE_PRINCIPAL = "service_principal"


async def _get_info(redis_url):
client = redis.Redis.from_url(redis_url)
info = await client.info()
Expand Down
13 changes: 11 additions & 2 deletions tests/test_asyncio/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from redis.exceptions import ConnectionError
from redis.utils import str_if_bytes
from tests.conftest import get_endpoint, skip_if_redis_enterprise
from tests.entraid_utils import AuthType
from tests.test_asyncio.conftest import get_credential_provider

try:
Expand Down Expand Up @@ -616,8 +617,12 @@ class TestEntraIdCredentialsProvider:
"cred_provider_class": EntraIdCredentialsProvider,
"cred_provider_kwargs": {"block_for_initial": True},
},
{
"cred_provider_class": EntraIdCredentialsProvider,
"idp_kwargs": {"auth_type": AuthType.DEFAULT_AZURE_CREDENTIAL},
},
],
ids=["blocked", "non-blocked"],
ids=["blocked", "non-blocked", "DefaultAzureCredential"],
indirect=True,
)
@pytest.mark.asyncio
Expand Down Expand Up @@ -692,8 +697,12 @@ class TestClusterEntraIdCredentialsProvider:
"cred_provider_class": EntraIdCredentialsProvider,
"cred_provider_kwargs": {"block_for_initial": True},
},
{
"cred_provider_class": EntraIdCredentialsProvider,
"idp_kwargs": {"auth_type": AuthType.DEFAULT_AZURE_CREDENTIAL},
},
],
ids=["blocked", "non-blocked"],
ids=["blocked", "non-blocked", "DefaultAzureCredential"],
indirect=True,
)
@pytest.mark.asyncio
Expand Down
13 changes: 11 additions & 2 deletions tests/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
get_endpoint,
skip_if_redis_enterprise,
)
from tests.entraid_utils import AuthType

try:
from redis_entraid.cred_provider import EntraIdCredentialsProvider
Expand Down Expand Up @@ -585,8 +586,12 @@ class TestEntraIdCredentialsProvider:
"cred_provider_class": EntraIdCredentialsProvider,
"single_connection_client": True,
},
{
"cred_provider_class": EntraIdCredentialsProvider,
"idp_kwargs": {"auth_type": AuthType.DEFAULT_AZURE_CREDENTIAL},
},
],
ids=["pool", "single"],
ids=["pool", "single", "DefaultAzureCredential"],
indirect=True,
)
@pytest.mark.onlynoncluster
Expand Down Expand Up @@ -656,8 +661,12 @@ class TestClusterEntraIdCredentialsProvider:
"cred_provider_class": EntraIdCredentialsProvider,
"single_connection_client": True,
},
{
"cred_provider_class": EntraIdCredentialsProvider,
"idp_kwargs": {"auth_type": AuthType.DEFAULT_AZURE_CREDENTIAL},
},
],
ids=["pool", "single"],
ids=["pool", "single", "DefaultAzureCredential"],
indirect=True,
)
@pytest.mark.onlycluster
Expand Down
Loading