diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3653525..0d95909 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -12,6 +12,7 @@ env: AZURE_CLIENT_SECRET: ${{ secrets.IDP_CLIENT_CREDENTIAL }} AZURE_CLIENT_ID: ${{ secrets.IDP_CLIENT_ID }} AZURE_TENANT_ID: ${{ secrets.IDP_TENANT_ID }} + AZURE_REDIS_SCOPES: ${{ secrets.IDP_SCOPES }} jobs: tests: strategy: diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..6b09f57 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021-2023, Redis, inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/README.md b/README.md index 06d1394..5f30b10 100644 --- a/README.md +++ b/README.md @@ -50,6 +50,31 @@ from redis_entraid.cred_provider import * ### Step 2 - Create the credential provider via the factory method +Following factory methods are offered depends on authentication type you need: + +`create_from_managed_identity` - Creates a credential provider based on a managed identity. +Managed identities allow Azure services to authenticate without needing explicit credentials, as they are automatically assigned by Azure. + +`create_from_service_principal` - Creates a credential provider using a service principal. +A service principal is typically used when you want to authenticate as an application, rather than as a user, with Azure Active Directory. + +`create_from_default_azure_credential` - Creates a credential provider from a Default Azure Credential. +This method allows automatic selection of the appropriate credential mechanism based on the environment +(e.g., environment variables, managed identities, service principal, interactive browser etc.). + +#### Examples #### + +**Managed Identity** + +```python +credential_provider = create_from_managed_identity( + identity_type=ManagedIdentityType.SYSTEM_ASSIGNED, + resource="https://redis.azure.com/" +) +``` + +**Service principal** + ```python credential_provider = create_from_service_principal( CLIENT_ID, @@ -58,6 +83,17 @@ credential_provider = create_from_service_principal( ) ``` +**Default Azure Credential** + +```python +credential_provider = create_from_default_azure_credential( + ("https://redis.azure.com/.default",), +) +``` + +More examples available in [examples](https://github.com/redis/redis-py-entraid/tree/vv-default-azure-credentials/examples) +folder. + ### Step 3 - Provide optional token renewal configuration The default configuration would be applied, but you're able to customise it. diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/interactive_browser_login.py b/examples/interactive_browser_login.py new file mode 100644 index 0000000..28b5741 --- /dev/null +++ b/examples/interactive_browser_login.py @@ -0,0 +1,31 @@ +# Before run this example you need to configure your EntraID application the following way: +# +# 1. Enable "Allow public client flows" option, under "Authentication" section. +# 2. Add the Redirect URL of the web server that DefaultAzureCredential runs +# By default, uses port 8400, so the default Redirect URL looks like "http://localhost:8400". + +import os + +from redis import Redis +from redis_entraid.cred_provider import create_from_default_azure_credential + +def main(): + + # By default, interactive browser login is excluded so you need to enable it. + credential_provider = create_from_default_azure_credential( + scopes=("user.read",), + app_kwargs={ + "exclude_interactive_browser_credential": False, + "interactive_browser_client_id": os.getenv("AZURE_CLIENT_ID"), + "interactive_browser_tenant_id": os.getenv("AZURE_TENANT_ID"), + } + ) + + # Opens a browser tab. After you'll enter your username/password you'll be authenticated. + # When using Entra ID, Azure enforces TLS on your Redis connection. + client = Redis(host=HOST, port=PORT, ssl=True, ssl_cert_reqs=None, credential_provider=credential_provider) + print("The database size is: {}".format(client.dbsize())) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 25e427e..da98e8a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,8 +15,9 @@ dependencies = [ "redis~=5.3.0b3", "PyJWT~=2.9.0", "msal~=1.31.0", + "azure-identity~=1.20.0" ] [tool.setuptools.packages.find] include = ["redis_entraid"] -exclude = ["tests", ".github"] +exclude = ["tests", "examples", ".github"] diff --git a/redis_entraid/cred_provider.py b/redis_entraid/cred_provider.py index b514baa..31c8a6d 100644 --- a/redis_entraid/cred_provider.py +++ b/redis_entraid/cred_provider.py @@ -6,7 +6,8 @@ from redis_entraid.identity_provider import ManagedIdentityType, ManagedIdentityIdType, \ _create_provider_from_managed_identity, ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig, \ - _create_provider_from_service_principal + _create_provider_from_service_principal, DefaultAzureCredentialIdentityProviderConfig, \ + _create_provider_from_default_azure_credential DEFAULT_EXPIRATION_REFRESH_RATIO = 0.7 DEFAULT_LOWER_REFRESH_BOUND_MILLIS = 0 @@ -158,4 +159,42 @@ def create_from_service_principal( token_kwargs=token_kwargs, ) idp = _create_provider_from_service_principal(service_principal_config) + return EntraIdCredentialsProvider(idp, token_manager_config) + + +def create_from_default_azure_credential( + scopes: Tuple[str], + tenant_id: Optional[str] = None, + authority: Optional[str] = None, + token_kwargs: Optional[dict] = {}, + app_kwargs: Optional[dict] = {}, + token_manager_config: Optional[TokenManagerConfig] = TokenManagerConfig( + DEFAULT_EXPIRATION_REFRESH_RATIO, + DEFAULT_LOWER_REFRESH_BOUND_MILLIS, + DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, + RetryPolicy( + DEFAULT_MAX_ATTEMPTS, + DEFAULT_DELAY_IN_MS + ) + ) +) -> EntraIdCredentialsProvider: + """ + Create a credential provider from a Default Azure credential. + https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity.defaultazurecredential?view=azure-python + + :param scopes: Service principal scopes. Fallback to default scopes if None. + :param tenant_id: Optional tenant to include in the token request. + :param authority: Custom authority, by default used 'login.microsoftonline.com' + :param token_kwargs: Optional token arguments applied when retrieving tokens. + :param app_kwargs: Optional keyword arguments to pass when instantiating application. + :param token_manager_config: Token manager specific configuration. + """ + default_azure_credential_config = DefaultAzureCredentialIdentityProviderConfig( + scopes=scopes, + authority=authority, + additional_tenant_id=tenant_id, + token_kwargs=token_kwargs, + app_kwargs=app_kwargs, + ) + idp = _create_provider_from_default_azure_credential(default_azure_credential_config) return EntraIdCredentialsProvider(idp, token_manager_config) \ No newline at end of file diff --git a/redis_entraid/identity_provider.py b/redis_entraid/identity_provider.py index f3a5c45..9f9ec49 100644 --- a/redis_entraid/identity_provider.py +++ b/redis_entraid/identity_provider.py @@ -1,8 +1,9 @@ from dataclasses import dataclass, field from enum import Enum -from typing import Optional, Union, Callable, Any, List +from typing import Optional, Any, List, Tuple, Iterable import requests +from azure.identity import DefaultAzureCredential from msal import ( ConfidentialClientApplication, ManagedIdentityClient, @@ -41,69 +42,120 @@ class ServicePrincipalIdentityProviderConfig: scopes: Optional[List[str]] = None timeout: Optional[float] = None tenant_id: Optional[str] = None - token_kwargs: Optional[dict] = None + token_kwargs: Optional[dict] = field(default_factory=dict) app_kwargs: Optional[dict] = field(default_factory=dict) -class EntraIDIdentityProvider(IdentityProviderInterface): - """ - EntraID Identity Provider implementation. - It's recommended to use an additional factory methods to simplify object instantiation. +@dataclass +class DefaultAzureCredentialIdentityProviderConfig: + scopes: Iterable[str] + additional_tenant_id: Optional[str] = None + authority: Optional[str] = None + token_kwargs: Optional[dict] = field(default_factory=dict) + app_kwargs: Optional[dict] = field(default_factory=dict) + - Methods: create_provider_from_managed_identity, create_provider_from_service_principal. +class ManagedIdentityProvider(IdentityProviderInterface): + """ + Identity Provider implementation for Azure Managed Identity auth type. """ def __init__( self, - app: Union[ManagedIdentityClient, ConfidentialClientApplication], - scopes : List = [], - resource: str = '', + app: ManagedIdentityClient, + resource: str, **kwargs ): + """ + :param kwargs: See: :class:`ManagedIdentityClient` for additional configuration. + """ self._app = app - self._scopes = scopes self._resource = resource self._kwargs = kwargs def request_token(self, force_refresh=False) -> TokenInterface: """ - Request token from identity provider. - Force refresh argument is optional and works only with Service Principal auth method. + Request token from identity provider. Force refresh isn't supported for this provider type. + """ + try: + response = self._app.acquire_token_for_client(resource=self._resource, **self._kwargs) + + if "error" in response: + raise RequestTokenErr(response["error_description"]) + except Exception as e: + raise RequestTokenErr(e) + + return JWToken(response["access_token"]) - :param force_refresh: - :return: TokenInterface + +class ServicePrincipalProvider(IdentityProviderInterface): + """ + Identity Provider implementation for Azure Service Principal auth type. + """ + def __init__( + self, + app: ConfidentialClientApplication, + scopes: Optional[List[str]] = None, + **kwargs + ): + """ + :param kwargs: See: :class:`ConfidentialClientApplication` for additional configuration. """ - if isinstance(self._app, ManagedIdentityClient): - return self._get_token(self._app.acquire_token_for_client, resource=self._resource) + self._app = app + self._scopes = scopes + self._kwargs = kwargs + def request_token(self, force_refresh=False) -> TokenInterface: + """ + Request token from identity provider. + """ if force_refresh: self._app.remove_tokens_for_client() - return self._get_token( - self._app.acquire_token_for_client, - scopes=self._scopes, - **self._kwargs - ) - - def _get_token(self, callback: Callable, **kwargs) -> JWToken: try: - response = callback(**kwargs) + response = self._app.acquire_token_for_client(scopes=self._scopes, **self._kwargs) if "error" in response: raise RequestTokenErr(response["error_description"]) + except Exception as e: + raise RequestTokenErr(e) + + return JWToken(response["access_token"]) + + +class DefaultAzureCredentialProvider(IdentityProviderInterface): + """ + Identity Provider implementation for Default Azure Credential flow. + """ + + def __init__( + self, + app: DefaultAzureCredential, + scopes: Tuple[str], + additional_tenant_id: Optional[str] = None, + **kwargs + ): + self._app = app + self._scopes = scopes + self._additional_tenant_id = additional_tenant_id + self._kwargs = kwargs - return JWToken(callback(**kwargs)["access_token"]) + def request_token(self, force_refresh=False) -> TokenInterface: + try: + response = self._app.get_token(*self._scopes, tenant_id=self._additional_tenant_id, **self._kwargs) except Exception as e: raise RequestTokenErr(e) + return JWToken(response.token) + -def _create_provider_from_managed_identity(config: ManagedIdentityProviderConfig) -> EntraIDIdentityProvider: +def _create_provider_from_managed_identity(config: ManagedIdentityProviderConfig) -> ManagedIdentityProvider: """ - Create an EntraID identity provider following Managed Identity auth flow. + Create a Managed identity provider following Managed Identity auth flow. :param config: Config for managed assigned identity provider See: :class:`ManagedIdentityClient` acquire_token_for_client method. - :return: :class:`EntraIDIdentityProvider` + :return: :class:`ManagedIdentityProvider` """ if config.identity_type == ManagedIdentityType.USER_ASSIGNED: if config.id_type is None or config.id_value == '': @@ -118,16 +170,16 @@ def _create_provider_from_managed_identity(config: ManagedIdentityProviderConfig managed_identity = config.identity_type.value() app = ManagedIdentityClient(managed_identity, http_client=requests.Session()) - return EntraIDIdentityProvider(app, [], config.resource, **config.kwargs) + return ManagedIdentityProvider(app, config.resource, **config.kwargs) -def _create_provider_from_service_principal(config: ServicePrincipalIdentityProviderConfig) -> EntraIDIdentityProvider: +def _create_provider_from_service_principal(config: ServicePrincipalIdentityProviderConfig) -> ServicePrincipalProvider: """ - Create an EntraID identity provider following Service Principal auth flow. + Create a Service Principal identity provider following Service Principal auth flow. :param config: Config for service principal identity provider - :return: :class:`EntraIDIdentityProvider` + :return: :class:`ServicePrincipalProvider` See: :class:`ConfidentialClientApplication`. """ @@ -146,4 +198,23 @@ def _create_provider_from_service_principal(config: ServicePrincipalIdentityProv authority=authority, **config.app_kwargs ) - return EntraIDIdentityProvider(app, scopes, **config.token_kwargs) + return ServicePrincipalProvider(app, scopes, **config.token_kwargs) + + +def _create_provider_from_default_azure_credential( + config: DefaultAzureCredentialIdentityProviderConfig +) -> DefaultAzureCredentialProvider: + """ + Create a Default Azure Credential identity provider following Default Azure Credential flow. + + :param config: Config for default Azure Credential identity provider + :return: :class:`DefaultAzureCredentialProvider` + See: :class:`DefaultAzureCredential`. + """ + + app = DefaultAzureCredential( + authority=config.authority, + **config.app_kwargs + ) + + return DefaultAzureCredentialProvider(app, config.scopes, config.additional_tenant_id, **config.token_kwargs) diff --git a/requirements.txt b/requirements.txt index 6574314..b55b0b6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ PyJWT~=2.9.0 msal~=1.31.0 +azure-identity~=1.20.0 redis==5.3.0b4 requests~=2.32.3 \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 1455563..26a1ce9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,18 +8,26 @@ from redis_entraid.cred_provider import DEFAULT_EXPIRATION_REFRESH_RATIO, \ DEFAULT_LOWER_REFRESH_BOUND_MILLIS, DEFAULT_MAX_ATTEMPTS, DEFAULT_DELAY_IN_MS, \ - DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, create_from_service_principal, create_from_managed_identity -from redis_entraid.identity_provider import ManagedIdentityType, EntraIDIdentityProvider, ManagedIdentityIdType, \ + DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, create_from_service_principal, create_from_managed_identity, \ + create_from_default_azure_credential +from redis_entraid.identity_provider import ManagedIdentityType, ManagedIdentityIdType, \ ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig, _create_provider_from_managed_identity, \ - _create_provider_from_service_principal + _create_provider_from_service_principal, \ + DefaultAzureCredentialIdentityProviderConfig, _create_provider_from_default_azure_credential, \ + ManagedIdentityProvider, ServicePrincipalProvider, DefaultAzureCredentialProvider class AuthType(Enum): MANAGED_IDENTITY = "managed_identity" SERVICE_PRINCIPAL = "service_principal" + DEFAULT_AZURE_CREDENTIAL = "default_azure_credential" -def get_identity_provider_config(request) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]: +def get_identity_provider_config(request) -> Union[ + ManagedIdentityProviderConfig, + ServicePrincipalIdentityProviderConfig, + DefaultAzureCredentialIdentityProviderConfig +]: if hasattr(request, "param"): kwargs = request.param.get("idp_kwargs", {}) else: @@ -30,6 +38,9 @@ def get_identity_provider_config(request) -> Union[ManagedIdentityProviderConfig if auth_type == AuthType.MANAGED_IDENTITY: return _get_managed_identity_provider_config(request) + if auth_type == AuthType.DEFAULT_AZURE_CREDENTIAL: + return _get_default_azure_credential_provider_config(request) + return _get_service_principal_provider_config(request) @@ -82,6 +93,27 @@ def _get_service_principal_provider_config(request) -> ServicePrincipalIdentityP app_kwargs=kwargs ) + +def _get_default_azure_credential_provider_config(request) -> DefaultAzureCredentialIdentityProviderConfig: + scopes = os.getenv("AZURE_REDIS_SCOPES", ()) + + if hasattr(request, "param"): + kwargs = request.param.get("idp_kwargs", {}) + token_kwargs = request.param.get("token_kwargs", {}) + else: + kwargs = {} + token_kwargs = {} + + if isinstance(scopes, str): + scopes = scopes.split(',') + + return DefaultAzureCredentialIdentityProviderConfig( + scopes=scopes, + app_kwargs=kwargs, + token_kwargs=token_kwargs + ) + + def get_credential_provider(request) -> CredentialProvider: if hasattr(request, "param"): cred_provider_kwargs = request.param.get("cred_provider_kwargs", {}) @@ -124,6 +156,16 @@ def get_credential_provider(request) -> CredentialProvider: token_mgr_config, ) + if isinstance(idp_config, DefaultAzureCredentialIdentityProviderConfig): + return create_from_default_azure_credential( + idp_config.scopes, + idp_config.additional_tenant_id, + idp_config.authority, + idp_config.token_kwargs, + idp_config.app_kwargs, + token_mgr_config, + ) + return create_from_managed_identity( idp_config.identity_type, idp_config.resource, @@ -139,11 +181,18 @@ def credential_provider(request) -> CredentialProvider: return get_credential_provider(request) @pytest.fixture() -def identity_provider(request) -> EntraIDIdentityProvider: +def identity_provider(request) -> Union[ + ManagedIdentityProvider, + ServicePrincipalProvider, + DefaultAzureCredentialProvider +]: config = _identity_provider_config(request) if isinstance(config, ManagedIdentityProviderConfig): return _create_provider_from_managed_identity(config) + if isinstance(config, DefaultAzureCredentialIdentityProviderConfig): + return _create_provider_from_default_azure_credential(config) + return _create_provider_from_service_principal(config) def _identity_provider_config(request) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]: diff --git a/tests/test_cred_provider.py b/tests/test_cred_provider.py index f7b49b4..5571f8e 100644 --- a/tests/test_cred_provider.py +++ b/tests/test_cred_provider.py @@ -46,6 +46,35 @@ def test_get_credentials_managed_identity(self, credential_provider: EntraIdCred credentials = credential_provider.get_credentials() assert len(credentials) == 2 + @pytest.mark.parametrize( + "credential_provider", + [ + { + "idp_kwargs": {"auth_type": AuthType.DEFAULT_AZURE_CREDENTIAL}, + }, + ], + ids=["Default Azure Credentials (via EnvironmentCredential)"], + indirect=True, + ) + def test_get_credentials_default_azure_credential_env(self, credential_provider: EntraIdCredentialsProvider): + credentials = credential_provider.get_credentials() + assert len(credentials) == 2 + + @pytest.mark.parametrize( + "credential_provider", + [ + { + "idp_kwargs": {"auth_type": AuthType.DEFAULT_AZURE_CREDENTIAL}, + }, + ], + ids=["Default Azure Credentials (via ManagedIdentityCredential)"], + indirect=True, + ) + @pytest.mark.managed_identity + def test_get_credentials_default_azure_credential_managed(self, credential_provider: EntraIdCredentialsProvider): + credentials = credential_provider.get_credentials() + assert len(credentials) == 2 + @pytest.mark.parametrize( "credential_provider", @@ -141,7 +170,7 @@ def on_next(token: TokenInterface): "credential_provider", [ { - "cred_provider_kwargs": {"expiration_refresh_ratio": 0.00002}, + "cred_provider_kwargs": {"expiration_refresh_ratio": 0.00003}, } ], indirect=True, diff --git a/tests/test_identity_provider.py b/tests/test_identity_provider.py index 635cb3d..471e409 100644 --- a/tests/test_identity_provider.py +++ b/tests/test_identity_provider.py @@ -1,13 +1,26 @@ import pytest from msal import TokenCache -from redis_entraid.identity_provider import EntraIDIdentityProvider +from redis_entraid.identity_provider import ServicePrincipalProvider, DefaultAzureCredentialProvider +from tests.conftest import AuthType class TestEntraIDIdentityProvider: CUSTOM_CACHE = TokenCache() - def test_request_token_from_service_principal_identity(self, identity_provider: EntraIDIdentityProvider): + def test_request_token_from_service_principal_identity(self, identity_provider: ServicePrincipalProvider): + assert identity_provider.request_token(force_refresh=True) + + @pytest.mark.parametrize( + "identity_provider", + [ + { + "idp_kwargs": {"auth_type": AuthType.DEFAULT_AZURE_CREDENTIAL}, + } + ], + indirect=True, + ) + def test_request_token_from_default_azure_credential(self, identity_provider: DefaultAzureCredentialProvider): assert identity_provider.request_token(force_refresh=True) @pytest.mark.parametrize( @@ -19,7 +32,7 @@ def test_request_token_from_service_principal_identity(self, identity_provider: ], indirect=True, ) - def test_request_token_caches_token_after_initial_request(self, identity_provider: EntraIDIdentityProvider): + def test_request_token_caches_token_after_initial_request(self, identity_provider: ServicePrincipalProvider): assert len(list(self.CUSTOM_CACHE.search(TokenCache.CredentialType.ACCESS_TOKEN))) == 0 token = identity_provider.request_token()