3
3
# Licensed under the MIT License.
4
4
# ------------------------------------
5
5
from binascii import hexlify
6
- from typing import TYPE_CHECKING
6
+ from typing import cast , NamedTuple , TYPE_CHECKING
7
7
8
8
from cryptography import x509
9
9
from cryptography .hazmat .primitives import hashes , serialization
15
15
from .._internal .client_credential_base import ClientCredentialBase
16
16
17
17
if TYPE_CHECKING :
18
+ # pylint:disable=ungrouped-imports
18
19
from typing import Any , Optional , Union
19
20
20
21
@@ -28,13 +29,13 @@ class CertificateCredential(ClientCredentialBase):
28
29
29
30
:param str tenant_id: ID of the service principal's tenant. Also called its "directory" ID.
30
31
:param str client_id: the service principal's client ID
31
- :param str certificate_path: path to a PEM-encoded certificate file including the private key. If not provided,
32
- **certificate_data** is required.
32
+ :param str certificate_path: Optional path to a certificate file in PEM or PKCS12 format, including the private
33
+ key. If not provided, **certificate_data** is required.
33
34
34
- :keyword str authority: Authority of an Azure Active Directory endpoint, for example ' login.microsoftonline.com' ,
35
+ :keyword str authority: Authority of an Azure Active Directory endpoint, for example " login.microsoftonline.com" ,
35
36
the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts`
36
37
defines authorities for other clouds.
37
- :keyword bytes certificate_data: the bytes of a certificate in PEM format, including the private key
38
+ :keyword bytes certificate_data: the bytes of a certificate in PEM or PKCS12 format, including the private key
38
39
:keyword password: The certificate's password. If a unicode string, it will be encoded as UTF-8. If the certificate
39
40
requires a different encoding, pass appropriately encoded bytes instead.
40
41
:paramtype password: str or bytes
@@ -76,6 +77,42 @@ def extract_cert_chain(pem_bytes):
76
77
return b"" .join (chain .splitlines ())
77
78
78
79
80
+ _Cert = NamedTuple ("_Cert" , [("pem_bytes" , bytes ), ("private_key" , "Any" ), ("fingerprint" , bytes )])
81
+
82
+
83
+ def load_pem_certificate (certificate_data , password ):
84
+ # type: (bytes, Optional[bytes]) -> _Cert
85
+ private_key = serialization .load_pem_private_key (certificate_data , password , backend = default_backend ())
86
+ cert = x509 .load_pem_x509_certificate (certificate_data , default_backend ())
87
+ fingerprint = cert .fingerprint (hashes .SHA1 ()) # nosec
88
+ return _Cert (certificate_data , private_key , fingerprint )
89
+
90
+
91
+ def load_pkcs12_certificate (certificate_data , password ):
92
+ # type: (bytes, Optional[bytes]) -> _Cert
93
+ from cryptography .hazmat .primitives .serialization import Encoding , NoEncryption , pkcs12 , PrivateFormat
94
+
95
+ private_key , cert , additional_certs = pkcs12 .load_key_and_certificates (
96
+ certificate_data , password , backend = default_backend ()
97
+ )
98
+ if not private_key :
99
+ raise ValueError ("The certificate must include its private key" )
100
+ if not cert :
101
+ # mentioning PEM here because we raise this error when certificate_data is garbage
102
+ raise ValueError ("Failed to deserialize certificate in PEM or PKCS12 format" )
103
+
104
+ # This serializes the private key without any encryption it may have had. Doing so doesn't violate security
105
+ # boundaries because this representation of the key is kept in memory. We already have the key and its
106
+ # password, if any, in memory.
107
+ key_bytes = private_key .private_bytes (Encoding .PEM , PrivateFormat .PKCS8 , NoEncryption ())
108
+ pem_sections = [key_bytes ] + [c .public_bytes (Encoding .PEM ) for c in [cert ] + additional_certs ]
109
+ pem_bytes = b"" .join (pem_sections )
110
+
111
+ fingerprint = cert .fingerprint (hashes .SHA1 ()) # nosec
112
+
113
+ return _Cert (pem_bytes , private_key , fingerprint )
114
+
115
+
79
116
def get_client_credential (certificate_path , password = None , certificate_data = None , send_certificate_chain = False , ** _ ):
80
117
# type: (Optional[str], Optional[Union[bytes, str]], Optional[bytes], bool, **Any) -> dict
81
118
"""Load a certificate from a filesystem path or bytes, return it as a dict suitable for msal.ClientApplication"""
@@ -88,24 +125,28 @@ def get_client_credential(certificate_path, password=None, certificate_data=None
88
125
elif not certificate_data :
89
126
raise ValueError ('CertificateCredential requires a value for either "certificate_path" or "certificate_data"' )
90
127
91
- if isinstance (password , six .text_type ):
92
- password = password .encode (encoding = "utf-8" )
128
+ if password :
129
+ # if password is already bytes, this won't change its encoding
130
+ password = six .ensure_binary (password , "utf-8" )
131
+ password = cast ("Optional[bytes]" , password )
93
132
94
- private_key = serialization .load_pem_private_key (certificate_data , password = password , backend = default_backend ())
95
- if not isinstance (private_key , RSAPrivateKey ):
96
- raise ValueError ("CertificateCredential requires an RSA private key because it uses RS256 for signing" )
133
+ if certificate_data .startswith (b"-----" ):
134
+ cert = load_pem_certificate (certificate_data , password )
135
+ else :
136
+ cert = load_pkcs12_certificate (certificate_data , password )
137
+ password = None # load_pkcs12_certificate returns cert.pem_bytes decrypted
97
138
98
- cert = x509 . load_pem_x509_certificate ( certificate_data , default_backend ())
99
- fingerprint = cert . fingerprint ( hashes . SHA1 ()) # nosec
139
+ if not isinstance ( cert . private_key , RSAPrivateKey ):
140
+ raise ValueError ( "CertificateCredential requires an RSA private key because it uses RS256 for signing" )
100
141
101
- client_credential = {"private_key" : certificate_data , "thumbprint" : hexlify (fingerprint ).decode ("utf-8" )}
142
+ client_credential = {"private_key" : cert . pem_bytes , "thumbprint" : hexlify (cert . fingerprint ).decode ("utf-8" )}
102
143
if password :
103
144
client_credential ["passphrase" ] = password
104
145
105
146
if send_certificate_chain :
106
147
try :
107
148
# the JWT needs the whole chain but load_pem_x509_certificate deserializes only the signing cert
108
- chain = extract_cert_chain (certificate_data )
149
+ chain = extract_cert_chain (cert . pem_bytes )
109
150
client_credential ["public_certificate" ] = six .ensure_str (chain )
110
151
except ValueError as ex :
111
152
# we shouldn't land here--cryptography already loaded the cert and would have raised if it were malformed
0 commit comments