From b1cee3f754afe702188d86925b17f5fa99bb0b4c Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Thu, 28 Jan 2021 14:55:18 -0800 Subject: [PATCH 1/4] accept certs as bytes --- .../identity/_credentials/certificate.py | 72 ++++++++++--------- .../identity/aio/_credentials/certificate.py | 37 ++++++++-- 2 files changed, 70 insertions(+), 39 deletions(-) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/certificate.py b/sdk/identity/azure-identity/azure/identity/_credentials/certificate.py index 49fc06ee0f67..bec722cf8e19 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/certificate.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/certificate.py @@ -14,7 +14,7 @@ from .._internal.client_credential_base import ClientCredentialBase if TYPE_CHECKING: - from typing import Any + from typing import Any, Optional, Union class CertificateCredential(ClientCredentialBase): @@ -22,11 +22,13 @@ class CertificateCredential(ClientCredentialBase): :param str tenant_id: ID of the service principal's tenant. Also called its 'directory' ID. :param str client_id: the service principal's client ID - :param str certificate_path: path to a PEM-encoded certificate file including the private key. + :param str certificate_path: path to a PEM-encoded certificate file including the private key. If not provided, + `certificate_bytes` is required. :keyword str authority: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts` defines authorities for other clouds. + :keyword bytes certificate_bytes: the bytes of a certificate in PEM format, including the private key :keyword password: The certificate's password. If a unicode string, it will be encoded as UTF-8. If the certificate requires a different encoding, pass appropriately encoded bytes instead. :paramtype password: str or bytes @@ -39,37 +41,11 @@ class CertificateCredential(ClientCredentialBase): is unavailable. Default to False. Has no effect when `enable_persistent_cache` is False. """ - def __init__(self, tenant_id, client_id, certificate_path, **kwargs): - # type: (str, str, str, **Any) -> None + def __init__(self, tenant_id, client_id, certificate_path=None, **kwargs): + # type: (str, str, Optional[str], **Any) -> None validate_tenant_id(tenant_id) - if not certificate_path: - raise ValueError( - "'certificate_path' must be the path to a PEM file containing an x509 certificate and its private key" - ) - password = kwargs.pop("password", None) - if isinstance(password, six.text_type): - password = password.encode(encoding="utf-8") - - with open(certificate_path, "rb") as f: - pem_bytes = f.read() - - cert = x509.load_pem_x509_certificate(pem_bytes, default_backend()) - fingerprint = cert.fingerprint(hashes.SHA1()) # nosec - - client_credential = {"private_key": pem_bytes, "thumbprint": hexlify(fingerprint).decode("utf-8")} - if password: - client_credential["passphrase"] = password - - if kwargs.pop("send_certificate_chain", False): - try: - # the JWT needs the whole chain but load_pem_x509_certificate deserializes only the signing cert - chain = extract_cert_chain(pem_bytes) - client_credential["public_certificate"] = six.ensure_str(chain) - except ValueError as ex: - # we shouldn't land here, because load_pem_private_key should have raised when given a malformed file - message = 'Found no PEM encoded certificate in "{}"'.format(certificate_path) - six.raise_from(ValueError(message), ex) + client_credential = get_client_credential(certificate_path, **kwargs) super(CertificateCredential, self).__init__( client_id=client_id, client_credential=client_credential, tenant_id=tenant_id, **kwargs @@ -84,6 +60,38 @@ def extract_cert_chain(pem_bytes): start = pem_bytes.index(b"-----BEGIN CERTIFICATE-----") footer = b"-----END CERTIFICATE-----" end = pem_bytes.rindex(footer) - chain = pem_bytes[start:end + len(footer) + 1] + chain = pem_bytes[start : end + len(footer) + 1] return b"".join(chain.splitlines()) + + +def get_client_credential(certificate_path, password=None, certificate_bytes=None, send_certificate_chain=False, **_): + # type: (Optional[str], Optional[Union[bytes, str]], Optional[bytes], bool, **Any) -> dict + """Load a certificate from a filesystem path or bytes, return it as a dict suitable for msal.ClientApplication""" + + if certificate_path: + with open(certificate_path, "rb") as f: + certificate_bytes = f.read() + elif not certificate_bytes: + raise ValueError('This credential requires a value for "certificate_path" or "certificate_bytes"') + + if isinstance(password, six.text_type): + password = password.encode(encoding="utf-8") + + cert = x509.load_pem_x509_certificate(certificate_bytes, default_backend()) + fingerprint = cert.fingerprint(hashes.SHA1()) # nosec + + client_credential = {"private_key": certificate_bytes, "thumbprint": hexlify(fingerprint).decode("utf-8")} + if password: + client_credential["passphrase"] = password + + if send_certificate_chain: + try: + # the JWT needs the whole chain but load_pem_x509_certificate deserializes only the signing cert + chain = extract_cert_chain(certificate_bytes) + client_credential["public_certificate"] = six.ensure_str(chain) + except ValueError as ex: + # we shouldn't land here--cryptography already loaded the cert and would have raised if it were malformed + six.raise_from(ValueError("Malformed certificate"), ex) + + return client_credential diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py index 68ac0b8a0539..ef56142a76b5 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py @@ -4,30 +4,56 @@ # ------------------------------------ from typing import TYPE_CHECKING +from msal import TokenCache + from .._internal import AadClient, AsyncContextManager from .._internal.decorators import log_get_token_async -from ..._internal import CertificateCredentialBase +from ..._credentials.certificate import get_client_credential +from ..._internal import AadClientCertificate, validate_tenant_id +from ..._internal.persistent_cache import load_service_principal_cache if TYPE_CHECKING: - from typing import Any + from typing import Any, Optional from azure.core.credentials import AccessToken -class CertificateCredential(CertificateCredentialBase, AsyncContextManager): +class CertificateCredential(AsyncContextManager): """Authenticates as a service principal using a certificate. :param str tenant_id: ID of the service principal's tenant. Also called its 'directory' ID. :param str client_id: the service principal's client ID - :param str certificate_path: path to a PEM-encoded certificate file including the private key + :param str certificate_path: path to a PEM-encoded certificate file including the private key. If not provided, + `certificate_bytes` is required. :keyword str authority: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts` defines authorities for other clouds. + :keyword bytes certificate_bytes: the bytes of a certificate in PEM format, including the private key :keyword password: The certificate's password. If a unicode string, it will be encoded as UTF-8. If the certificate requires a different encoding, pass appropriately encoded bytes instead. :paramtype password: str or bytes """ + def __init__(self, tenant_id, client_id, certificate_path=None, **kwargs): + # type: (str, str, Optional[str], **Any) -> None + validate_tenant_id(tenant_id) + + client_credential = get_client_credential(certificate_path, **kwargs) + + self._certificate = AadClientCertificate( + client_credential["private_key"], password=client_credential.get("passphrase") + ) + + enable_persistent_cache = kwargs.pop("enable_persistent_cache", False) + if enable_persistent_cache: + allow_unencrypted = kwargs.pop("allow_unencrypted_cache", False) + cache = load_service_principal_cache(allow_unencrypted) + else: + cache = TokenCache() + + self._client = AadClient(tenant_id, client_id, cache=cache, **kwargs) + self._client_id = client_id + async def __aenter__(self): await self._client.__aenter__() return self @@ -61,6 +87,3 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py except Exception: # pylint: disable=broad-except pass return token - - def _get_auth_client(self, tenant_id, client_id, **kwargs): - return AadClient(tenant_id, client_id, **kwargs) From 70f1a7b9756721170fffd4a8fcc77a7258babeac Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Thu, 28 Jan 2021 14:58:49 -0800 Subject: [PATCH 2/4] remove unused base class --- .../azure/identity/_internal/__init__.py | 2 - .../_internal/certificate_credential_base.py | 61 ------------------- 2 files changed, 63 deletions(-) delete mode 100644 sdk/identity/azure-identity/azure/identity/_internal/certificate_credential_base.py diff --git a/sdk/identity/azure-identity/azure/identity/_internal/__init__.py b/sdk/identity/azure-identity/azure/identity/_internal/__init__.py index 46b420b5b1a7..da0c1ff1e20a 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/__init__.py @@ -47,7 +47,6 @@ def validate_tenant_id(tenant_id): from .aad_client_base import AadClientBase from .auth_code_redirect_handler import AuthCodeRedirectServer from .aadclient_certificate import AadClientCertificate -from .certificate_credential_base import CertificateCredentialBase from .client_secret_credential_base import ClientSecretCredentialBase from .decorators import wrap_exceptions from .interactive import InteractiveCredential @@ -72,7 +71,6 @@ def _scopes_to_resource(*scopes): "AadClientBase", "AuthCodeRedirectServer", "AadClientCertificate", - "CertificateCredentialBase", "ClientSecretCredentialBase", "get_default_authority", "InteractiveCredential", diff --git a/sdk/identity/azure-identity/azure/identity/_internal/certificate_credential_base.py b/sdk/identity/azure-identity/azure/identity/_internal/certificate_credential_base.py deleted file mode 100644 index 32524d3d5682..000000000000 --- a/sdk/identity/azure-identity/azure/identity/_internal/certificate_credential_base.py +++ /dev/null @@ -1,61 +0,0 @@ -# ------------------------------------ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -# ------------------------------------ -import abc - -from msal import TokenCache -import six - -from . import AadClientCertificate -from .persistent_cache import load_service_principal_cache -from .._internal import validate_tenant_id - -try: - ABC = abc.ABC -except AttributeError: # Python 2.7, abc exists, but not ABC - ABC = abc.ABCMeta("ABC", (object,), {"__slots__": ()}) # type: ignore - -try: - from typing import TYPE_CHECKING -except ImportError: - TYPE_CHECKING = False - -if TYPE_CHECKING: - # pylint:disable=unused-import - from typing import Any - - -class CertificateCredentialBase(ABC): - def __init__(self, tenant_id, client_id, certificate_path, **kwargs): - # type: (str, str, str, **Any) -> None - validate_tenant_id(tenant_id) - if not certificate_path: - raise ValueError( - "'certificate_path' must be the path to a PEM file containing an x509 certificate and its private key" - ) - - super(CertificateCredentialBase, self).__init__() - - password = kwargs.pop("password", None) - if isinstance(password, six.text_type): - password = password.encode(encoding="utf-8") - - with open(certificate_path, "rb") as f: - pem_bytes = f.read() - - self._certificate = AadClientCertificate(pem_bytes, password=password) - - enable_persistent_cache = kwargs.pop("enable_persistent_cache", False) - if enable_persistent_cache: - allow_unencrypted = kwargs.pop("allow_unencrypted_cache", False) - cache = load_service_principal_cache(allow_unencrypted) - else: - cache = TokenCache() - - self._client = self._get_auth_client(tenant_id, client_id, cache=cache, **kwargs) - self._client_id = client_id - - @abc.abstractmethod - def _get_auth_client(self, tenant_id, client_id, **kwargs): - pass From f424f7e53f1e614c77b7015ad7244801cb2bb034 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Thu, 28 Jan 2021 15:20:43 -0800 Subject: [PATCH 3/4] add tests --- sdk/identity/azure-identity/conftest.py | 57 ++++++++----------- .../tests/test_certificate_credential.py | 31 ++++++++++ .../test_certificate_credential_async.py | 31 +++++++++- .../azure-identity/tests/test_live.py | 19 +++++-- .../azure-identity/tests/test_live_async.py | 20 ++++--- 5 files changed, 111 insertions(+), 47 deletions(-) diff --git a/sdk/identity/azure-identity/conftest.py b/sdk/identity/azure-identity/conftest.py index c176fc659d34..fbca098472cd 100644 --- a/sdk/identity/azure-identity/conftest.py +++ b/sdk/identity/azure-identity/conftest.py @@ -6,6 +6,7 @@ import sys import pytest +import six from azure.identity._constants import DEVELOPER_SIGN_ON_CLIENT_ID, EnvironmentVariables @@ -62,43 +63,32 @@ def live_service_principal(): # pylint:disable=inconsistent-return-statements @pytest.fixture() -def live_certificate(live_service_principal): # pylint:disable=inconsistent-return-statements,redefined-outer-name - """Provides a path to a PEM-encoded certificate with no password""" - - pem_content = os.environ.get("PEM_CONTENT") - if not pem_content: - pytest.skip("Expected PEM content in environment variable 'PEM_CONTENT'") - return - - pem_path = os.path.join(os.path.dirname(__file__), "certificate.pem") - try: - with open(pem_path, "w") as pem_file: - pem_file.write(pem_content) - return dict(live_service_principal, cert_path=pem_path) - except IOError as ex: - pytest.skip("Failed to write file '{}': {}".format(pem_path, ex)) +def live_certificate(live_service_principal): + content = os.environ.get("PEM_CONTENT") + password_protected_content = os.environ.get("PEM_CONTENT_PASSWORD_PROTECTED") + password = os.environ.get("CERTIFICATE_PASSWORD") + if content and password_protected_content and password: + current_directory = os.path.dirname(__file__) + parameters = { + "cert_bytes": six.ensure_binary(content), + "cert_path": os.path.join(current_directory, "certificate.pem"), + "cert_with_password_bytes": six.ensure_binary(password_protected_content), + "cert_with_password_path": os.path.join(current_directory, "certificate-with-password.pem"), + "password": password, + } -@pytest.fixture() -def live_certificate_with_password(live_service_principal): - """Provides a path to a PEM-encoded, password-protected certificate, and its password""" + try: + with open(parameters["cert_path"], "wb") as f: + f.write(parameters["cert_bytes"]) + with open(parameters["cert_with_password_path"], "wb") as f: + f.write(parameters["cert_with_password_bytes"]) + except IOError as ex: + pytest.skip("Failed to write a file: {}".format(ex)) - pem_content = os.environ.get("PEM_CONTENT_PASSWORD_PROTECTED") - password = os.environ.get("CERTIFICATE_PASSWORD") - if not (pem_content and password): - pytest.skip( - "Expected password-protected PEM content in environment variable 'PEM_CONTENT_PASSWORD_PROTECTED'" - + " and the password in 'CERTIFICATE_PASSWORD'" - ) - return + return dict(live_service_principal, **parameters) - pem_path = os.path.join(os.path.dirname(__file__), "certificate-with-password.pem") - try: - with open(pem_path, "w") as pem_file: - pem_file.write(pem_content) - return dict(live_service_principal, cert_path=pem_path, password=password) - except IOError as ex: - pytest.skip("Failed to write file '{}': {}".format(pem_path, ex)) + pytest.skip("Missing PEM certificate configuration") @pytest.fixture() @@ -114,6 +104,7 @@ def live_user_details(): else: return user_details + @pytest.fixture() def event_loop(): """Ensure the event loop used by pytest-asyncio on Windows is ProactorEventLoop, which supports subprocesses. diff --git a/sdk/identity/azure-identity/tests/test_certificate_credential.py b/sdk/identity/azure-identity/tests/test_certificate_credential.py index d8562fd36094..f2f756faf24d 100644 --- a/sdk/identity/azure-identity/tests/test_certificate_credential.py +++ b/sdk/identity/azure-identity/tests/test_certificate_credential.py @@ -125,6 +125,21 @@ def test_authority(authority): assert kwargs["authority"] == expected_authority +def test_requires_certificate(): + """the credential should raise ValueError when not given a certificate""" + + with pytest.raises(ValueError): + CertificateCredential("tenant", "client-id") + with pytest.raises(ValueError): + CertificateCredential("tenant", "client-id", certificate_path=None) + with pytest.raises(ValueError): + CertificateCredential("tenant", "client-id", certificate_path="") + with pytest.raises(ValueError): + CertificateCredential("tenant", "client-id", certificate_bytes=None) + with pytest.raises(ValueError): + CertificateCredential("tenant", "client-id", certificate_path="", certificate_bytes=None) + + @pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS) @pytest.mark.parametrize("send_certificate_chain", (True, False)) def test_request_body(cert_path, cert_password, send_certificate_chain): @@ -158,6 +173,22 @@ def mock_send(request, **kwargs): token = cred.get_token(expected_scope) assert token.token == access_token + # credential should also accept the certificate as bytes + with open(cert_path, "rb") as f: + cert_bytes = f.read() + + cred = CertificateCredential( + tenant_id, + client_id, + certificate_bytes=cert_bytes, + password=cert_password, + transport=Mock(send=mock_send), + authority=authority, + send_certificate_chain=send_certificate_chain, + ) + token = cred.get_token(expected_scope) + assert token.token == access_token + def validate_jwt(request, client_id, pem_bytes, expect_x5c=False): """Validate the request meets AAD's expectations for a client credential grant using a certificate, as documented diff --git a/sdk/identity/azure-identity/tests/test_certificate_credential_async.py b/sdk/identity/azure-identity/tests/test_certificate_credential_async.py index 973994be2b9c..61d0ca14ce3c 100644 --- a/sdk/identity/azure-identity/tests/test_certificate_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_certificate_credential_async.py @@ -123,6 +123,21 @@ async def mock_send(request, **kwargs): assert token.token == access_token +def test_requires_certificate(): + """the credential should raise ValueError when not given a certificate""" + + with pytest.raises(ValueError): + CertificateCredential("tenant", "client-id") + with pytest.raises(ValueError): + CertificateCredential("tenant", "client-id", certificate_path=None) + with pytest.raises(ValueError): + CertificateCredential("tenant", "client-id", certificate_path="") + with pytest.raises(ValueError): + CertificateCredential("tenant", "client-id", certificate_bytes=None) + with pytest.raises(ValueError): + CertificateCredential("tenant", "client-id", certificate_path="", certificate_bytes=None) + + @pytest.mark.asyncio @pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS) async def test_request_body(cert_path, cert_password): @@ -144,8 +159,22 @@ async def mock_send(request, **kwargs): cred = CertificateCredential( tenant_id, client_id, cert_path, password=cert_password, transport=Mock(send=mock_send), authority=authority ) - token = await cred.get_token("scope") + token = await cred.get_token(expected_scope) + assert token.token == access_token + # credential should also accept the certificate as bytes + with open(cert_path, "rb") as f: + cert_bytes = f.read() + + cred = CertificateCredential( + tenant_id, + client_id, + certificate_bytes=cert_bytes, + password=cert_password, + transport=Mock(send=mock_send), + authority=authority, + ) + token = await cred.get_token(expected_scope) assert token.token == access_token diff --git a/sdk/identity/azure-identity/tests/test_live.py b/sdk/identity/azure-identity/tests/test_live.py index ae35e570e87f..cae96fb5c60f 100644 --- a/sdk/identity/azure-identity/tests/test_live.py +++ b/sdk/identity/azure-identity/tests/test_live.py @@ -25,18 +25,25 @@ def get_token(credential): def test_certificate_credential(live_certificate): + tenant_id = live_certificate["tenant_id"] + client_id = live_certificate["client_id"] + + credential = CertificateCredential(tenant_id, client_id, live_certificate["cert_path"]) + get_token(credential) + credential = CertificateCredential( - live_certificate["tenant_id"], live_certificate["client_id"], live_certificate["cert_path"] + tenant_id, client_id, live_certificate["cert_with_password_path"], password=live_certificate["password"] ) get_token(credential) + credential = CertificateCredential(tenant_id, client_id, certificate_bytes=live_certificate["cert_bytes"]) + get_token(credential) -def test_certificate_credential_with_password(live_certificate_with_password): credential = CertificateCredential( - live_certificate_with_password["tenant_id"], - live_certificate_with_password["client_id"], - live_certificate_with_password["cert_path"], - password=live_certificate_with_password["password"], + tenant_id, + client_id, + certificate_bytes=live_certificate["cert_with_password_bytes"], + password=live_certificate["password"], ) get_token(credential) diff --git a/sdk/identity/azure-identity/tests/test_live_async.py b/sdk/identity/azure-identity/tests/test_live_async.py index c85f66020b9a..c125a6da6336 100644 --- a/sdk/identity/azure-identity/tests/test_live_async.py +++ b/sdk/identity/azure-identity/tests/test_live_async.py @@ -18,19 +18,25 @@ async def get_token(credential): @pytest.mark.asyncio async def test_certificate_credential(live_certificate): + tenant_id = live_certificate["tenant_id"] + client_id = live_certificate["client_id"] + + credential = CertificateCredential(tenant_id, client_id, live_certificate["cert_path"]) + await get_token(credential) + credential = CertificateCredential( - live_certificate["tenant_id"], live_certificate["client_id"], live_certificate["cert_path"] + tenant_id, client_id, live_certificate["cert_with_password_path"], password=live_certificate["password"] ) await get_token(credential) + credential = CertificateCredential(tenant_id, client_id, certificate_bytes=live_certificate["cert_bytes"]) + await get_token(credential) -@pytest.mark.asyncio -async def test_certificate_credential_with_password(live_certificate_with_password): credential = CertificateCredential( - live_certificate_with_password["tenant_id"], - live_certificate_with_password["client_id"], - live_certificate_with_password["cert_path"], - password=live_certificate_with_password["password"], + tenant_id, + client_id, + certificate_bytes=live_certificate["cert_with_password_bytes"], + password=live_certificate["password"], ) await get_token(credential) From fdd4c56bb93fa50673b049d10e315fb28232cb5f Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Thu, 28 Jan 2021 15:20:53 -0800 Subject: [PATCH 4/4] update changelog --- sdk/identity/azure-identity/CHANGELOG.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index 1cef31e82909..65e1113fd7c6 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -1,7 +1,12 @@ # Release History ## 1.5.1 (Unreleased) - +### Added +- `CertificateCredential` can load a certificate from bytes instead of a file + path. To provide a certificate as bytes, use the keyword argument + `certificate_bytes` instead of `certificate_path`, for example: + `CertificateCredential(tenant_id, client_id, certificate_bytes=cert_bytes)` + ([#14055](https://github.com/Azure/azure-sdk-for-python/issues/14055)) ## 1.5.0 (2020-11-11) ### Breaking Changes