diff --git a/src/aws_encryption_sdk/internal/formatting/serialize.py b/src/aws_encryption_sdk/internal/formatting/serialize.py index e7c86a0cb..bd71b3a1a 100644 --- a/src/aws_encryption_sdk/internal/formatting/serialize.py +++ b/src/aws_encryption_sdk/internal/formatting/serialize.py @@ -316,6 +316,6 @@ def serialize_wrapped_key(key_provider, wrapping_algorithm, wrapping_key_id, enc ) key_ciphertext = encrypted_wrapped_key.ciphertext + encrypted_wrapped_key.tag return EncryptedDataKey( - key_provider=MasterKeyInfo(provider_id=key_provider.provider_id, key_info=key_info), + key_provider=MasterKeyInfo(provider_id=key_provider.provider_id, key_info=key_info, key_name=wrapping_key_id), encrypted_data_key=key_ciphertext, ) diff --git a/src/aws_encryption_sdk/keyrings/aws_kms/__init__.py b/src/aws_encryption_sdk/keyrings/aws_kms/__init__.py index 3630644c6..b65a5a8d2 100644 --- a/src/aws_encryption_sdk/keyrings/aws_kms/__init__.py +++ b/src/aws_encryption_sdk/keyrings/aws_kms/__init__.py @@ -180,16 +180,17 @@ class _AwsKmsSingleCmkKeyring(Keyring): def on_encrypt(self, encryption_materials): # type: (EncryptionMaterials) -> EncryptionMaterials trace_info = MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=self._key_id) + new_materials = encryption_materials try: - if encryption_materials.data_encryption_key is None: + if new_materials.data_encryption_key is None: plaintext_key, encrypted_key = _do_aws_kms_generate_data_key( client_supplier=self._client_supplier, key_name=self._key_id, - encryption_context=encryption_materials.encryption_context, - algorithm=encryption_materials.algorithm, + encryption_context=new_materials.encryption_context, + algorithm=new_materials.algorithm, grant_tokens=self._grant_tokens, ) - encryption_materials.add_data_encryption_key( + new_materials = new_materials.with_data_encryption_key( data_encryption_key=plaintext_key, keyring_trace=KeyringTrace(wrapping_key=trace_info, flags=_GENERATE_FLAGS), ) @@ -197,8 +198,8 @@ def on_encrypt(self, encryption_materials): encrypted_key = _do_aws_kms_encrypt( client_supplier=self._client_supplier, key_name=self._key_id, - plaintext_data_key=encryption_materials.data_encryption_key, - encryption_context=encryption_materials.encryption_context, + plaintext_data_key=new_materials.data_encryption_key, + encryption_context=new_materials.encryption_context, grant_tokens=self._grant_tokens, ) except Exception: # pylint: disable=broad-except @@ -207,30 +208,30 @@ def on_encrypt(self, encryption_materials): _LOGGER.exception(message) raise EncryptKeyError(message) - encryption_materials.add_encrypted_data_key( + return new_materials.with_encrypted_data_key( encrypted_data_key=encrypted_key, keyring_trace=KeyringTrace(wrapping_key=trace_info, flags=_ENCRYPT_FLAGS) ) - return encryption_materials - def on_decrypt(self, decryption_materials, encrypted_data_keys): # type: (DecryptionMaterials, Iterable[EncryptedDataKey]) -> DecryptionMaterials + new_materials = decryption_materials + for edk in encrypted_data_keys: - if decryption_materials.data_encryption_key is not None: - return decryption_materials + if new_materials.data_encryption_key is not None: + return new_materials if ( edk.key_provider.provider_id == _PROVIDER_ID and edk.key_provider.key_info.decode("utf-8") == self._key_id ): - decryption_materials = _try_aws_kms_decrypt( + new_materials = _try_aws_kms_decrypt( client_supplier=self._client_supplier, - decryption_materials=decryption_materials, + decryption_materials=new_materials, grant_tokens=self._grant_tokens, encrypted_data_key=edk, ) - return decryption_materials + return new_materials @attr.s @@ -258,19 +259,21 @@ def on_encrypt(self, encryption_materials): def on_decrypt(self, decryption_materials, encrypted_data_keys): # type: (DecryptionMaterials, Iterable[EncryptedDataKey]) -> DecryptionMaterials + new_materials = decryption_materials + for edk in encrypted_data_keys: - if decryption_materials.data_encryption_key is not None: - return decryption_materials + if new_materials.data_encryption_key is not None: + return new_materials if edk.key_provider.provider_id == _PROVIDER_ID: - decryption_materials = _try_aws_kms_decrypt( + new_materials = _try_aws_kms_decrypt( client_supplier=self._client_supplier, - decryption_materials=decryption_materials, + decryption_materials=new_materials, grant_tokens=self._grant_tokens, encrypted_data_key=edk, ) - return decryption_materials + return new_materials def _try_aws_kms_decrypt(client_supplier, decryption_materials, grant_tokens, encrypted_data_key): @@ -293,14 +296,12 @@ def _try_aws_kms_decrypt(client_supplier, decryption_materials, grant_tokens, en except Exception: # pylint: disable=broad-except # We intentionally WANT to catch all exceptions here _LOGGER.exception("Unable to decrypt encrypted data key from %s", encrypted_data_key.key_provider) - else: - decryption_materials.add_data_encryption_key( - data_encryption_key=plaintext_key, - keyring_trace=KeyringTrace(wrapping_key=encrypted_data_key.key_provider, flags=_DECRYPT_FLAGS), - ) return decryption_materials - return decryption_materials + return decryption_materials.with_data_encryption_key( + data_encryption_key=plaintext_key, + keyring_trace=KeyringTrace(wrapping_key=encrypted_data_key.key_provider, flags=_DECRYPT_FLAGS), + ) def _do_aws_kms_decrypt(client_supplier, key_name, encrypted_data_key, encryption_context, grant_tokens): diff --git a/src/aws_encryption_sdk/keyrings/multi.py b/src/aws_encryption_sdk/keyrings/multi.py index d42ea365d..27e90c3c8 100644 --- a/src/aws_encryption_sdk/keyrings/multi.py +++ b/src/aws_encryption_sdk/keyrings/multi.py @@ -67,20 +67,21 @@ def on_encrypt(self, encryption_materials): "and encryption materials do not already contain a plaintext data key." ) + new_materials = encryption_materials + # Call on_encrypt on the generator keyring if it is provided if self.generator is not None: - - encryption_materials = self.generator.on_encrypt(encryption_materials=encryption_materials) + new_materials = self.generator.on_encrypt(encryption_materials=new_materials) # Check if data key is generated - if encryption_materials.data_encryption_key is None: + if new_materials.data_encryption_key is None: raise GenerateKeyError("Unable to generate data encryption key.") # Call on_encrypt on all other keyrings for keyring in self.children: - encryption_materials = keyring.on_encrypt(encryption_materials=encryption_materials) + new_materials = keyring.on_encrypt(encryption_materials=new_materials) - return encryption_materials + return new_materials def on_decrypt(self, decryption_materials, encrypted_data_keys): # type: (DecryptionMaterials, Iterable[EncryptedDataKey]) -> DecryptionMaterials @@ -92,10 +93,13 @@ def on_decrypt(self, decryption_materials, encrypted_data_keys): :rtype: DecryptionMaterials """ # Call on_decrypt on all keyrings till decryption is successful + new_materials = decryption_materials for keyring in self._decryption_keyrings: - if decryption_materials.data_encryption_key is not None: - return decryption_materials - decryption_materials = keyring.on_decrypt( - decryption_materials=decryption_materials, encrypted_data_keys=encrypted_data_keys + if new_materials.data_encryption_key is not None: + return new_materials + + new_materials = keyring.on_decrypt( + decryption_materials=new_materials, encrypted_data_keys=encrypted_data_keys ) - return decryption_materials + + return new_materials diff --git a/src/aws_encryption_sdk/keyrings/raw.py b/src/aws_encryption_sdk/keyrings/raw.py index fcf793f5f..deab1a122 100644 --- a/src/aws_encryption_sdk/keyrings/raw.py +++ b/src/aws_encryption_sdk/keyrings/raw.py @@ -11,7 +11,7 @@ from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey -from aws_encryption_sdk.exceptions import GenerateKeyError +from aws_encryption_sdk.exceptions import EncryptKeyError, GenerateKeyError from aws_encryption_sdk.identifiers import EncryptionKeyType, KeyringTraceFlag, WrappingAlgorithm from aws_encryption_sdk.internal.crypto.wrapping_keys import EncryptedData, WrappingKey from aws_encryption_sdk.internal.formatting.deserialize import deserialize_wrapped_key @@ -35,12 +35,13 @@ def _generate_data_key( encryption_materials, # type: EncryptionMaterials key_provider, # type: MasterKeyInfo ): - # type: (...) -> bytes + # type: (...) -> EncryptionMaterials """Generates plaintext data key for the keyring. :param EncryptionMaterials encryption_materials: Encryption materials for the keyring to modify. :param MasterKeyInfo key_provider: Information about the key in the keyring. - :return bytes: Plaintext data key + :rtype: EncryptionMaterials + :returns: Encryption materials containing a data encryption key """ # Check if encryption materials contain data encryption key if encryption_materials.data_encryption_key is not None: @@ -60,10 +61,9 @@ def _generate_data_key( # plaintext_data_key to RawDataKey data_encryption_key = RawDataKey(key_provider=key_provider, data_key=plaintext_data_key) - # Add generated data key to encryption_materials - encryption_materials.add_data_encryption_key(data_encryption_key, keyring_trace) - - return plaintext_data_key + return encryption_materials.with_data_encryption_key( + data_encryption_key=data_encryption_key, keyring_trace=keyring_trace + ) @attr.s @@ -123,17 +123,20 @@ def on_encrypt(self, encryption_materials): """Generate a data key if not present and encrypt it using any available wrapping key :param EncryptionMaterials encryption_materials: Encryption materials for the keyring to modify - :returns: Optionally modified encryption materials + :returns: Encryption materials containing data key and encrypted data key :rtype: EncryptionMaterials """ - if encryption_materials.data_encryption_key is None: - _generate_data_key(encryption_materials=encryption_materials, key_provider=self._key_provider) + new_materials = encryption_materials + + if new_materials.data_encryption_key is None: + # Get encryption materials with a new data key. + new_materials = _generate_data_key(encryption_materials=new_materials, key_provider=self._key_provider) try: # Encrypt data key encrypted_wrapped_key = self._wrapping_key_structure.encrypt( - plaintext_data_key=encryption_materials.data_encryption_key.data_key, - encryption_context=encryption_materials.encryption_context, + plaintext_data_key=new_materials.data_encryption_key.data_key, + encryption_context=new_materials.encryption_context, ) # EncryptedData to EncryptedDataKey @@ -144,18 +147,17 @@ def on_encrypt(self, encryption_materials): encrypted_wrapped_key=encrypted_wrapped_key, ) except Exception: # pylint: disable=broad-except - error_message = "Raw AES Keyring unable to encrypt data key" + error_message = "Raw AES keyring unable to encrypt data key" _LOGGER.exception(error_message) - return encryption_materials + raise EncryptKeyError(error_message) # Update Keyring Trace keyring_trace = KeyringTrace( - wrapping_key=encrypted_data_key.key_provider, flags={KeyringTraceFlag.ENCRYPTED_DATA_KEY} + wrapping_key=self._key_provider, + flags={KeyringTraceFlag.ENCRYPTED_DATA_KEY, KeyringTraceFlag.SIGNED_ENCRYPTION_CONTEXT}, ) - # Add encrypted data key to encryption_materials - encryption_materials.add_encrypted_data_key(encrypted_data_key=encrypted_data_key, keyring_trace=keyring_trace) - return encryption_materials + return new_materials.with_encrypted_data_key(encrypted_data_key=encrypted_data_key, keyring_trace=keyring_trace) def on_decrypt(self, decryption_materials, encrypted_data_keys): # type: (DecryptionMaterials, Iterable[EncryptedDataKey]) -> DecryptionMaterials @@ -163,19 +165,18 @@ def on_decrypt(self, decryption_materials, encrypted_data_keys): :param DecryptionMaterials decryption_materials: Decryption materials for the keyring to modify :param List[EncryptedDataKey] encrypted_data_keys: List of encrypted data keys - :returns: Optionally modified decryption materials + :returns: Decryption materials that MAY include a plaintext data key :rtype: DecryptionMaterials """ - if decryption_materials.data_encryption_key is not None: - return decryption_materials + new_materials = decryption_materials + + if new_materials.data_encryption_key is not None: + return new_materials # Decrypt data key expected_key_info_len = len(self._key_info_prefix) + self._wrapping_algorithm.algorithm.iv_len for key in encrypted_data_keys: - if decryption_materials.data_encryption_key is not None: - return decryption_materials - if ( key.key_provider.provider_id != self._key_provider.provider_id or len(key.key_provider.key_info) != expected_key_info_len @@ -192,22 +193,31 @@ def on_decrypt(self, decryption_materials, encrypted_data_keys): try: plaintext_data_key = self._wrapping_key_structure.decrypt( encrypted_wrapped_data_key=encrypted_wrapped_key, - encryption_context=decryption_materials.encryption_context, + encryption_context=new_materials.encryption_context, ) except Exception: # pylint: disable=broad-except + # We intentionally WANT to catch all exceptions here error_message = "Raw AES Keyring unable to decrypt data key" _LOGGER.exception(error_message) - return decryption_materials + # The Raw AES keyring MUST evaluate every encrypted data key + # until it either succeeds or runs out of encrypted data keys. + continue # Create a keyring trace - keyring_trace = KeyringTrace(wrapping_key=self._key_provider, flags={KeyringTraceFlag.DECRYPTED_DATA_KEY}) + keyring_trace = KeyringTrace( + wrapping_key=self._key_provider, + flags={KeyringTraceFlag.DECRYPTED_DATA_KEY, KeyringTraceFlag.VERIFIED_ENCRYPTION_CONTEXT}, + ) # Update decryption materials data_encryption_key = RawDataKey(key_provider=self._key_provider, data_key=plaintext_data_key) - decryption_materials.add_data_encryption_key(data_encryption_key, keyring_trace) - return decryption_materials + return new_materials.with_data_encryption_key( + data_encryption_key=data_encryption_key, keyring_trace=keyring_trace + ) + + return new_materials @attr.s @@ -331,22 +341,24 @@ def on_encrypt(self, encryption_materials): and encrypt it using any available wrapping key in any child keyring. :param EncryptionMaterials encryption_materials: Encryption materials for keyring to modify. - :returns: Optionally modified encryption materials. + :returns: Encryption materials containing data key and encrypted data key :rtype: EncryptionMaterials """ - if encryption_materials.data_encryption_key is None: - _generate_data_key(encryption_materials=encryption_materials, key_provider=self._key_provider) + new_materials = encryption_materials + + if new_materials.data_encryption_key is None: + new_materials = _generate_data_key(encryption_materials=new_materials, key_provider=self._key_provider) if self._public_wrapping_key is None: - return encryption_materials + # This should be impossible, but just in case, give a useful error message. + raise EncryptKeyError("Raw RSA keyring unable to encrypt data key: no public key available") try: # Encrypt data key encrypted_wrapped_key = EncryptedData( iv=None, ciphertext=self._public_wrapping_key.encrypt( - plaintext=encryption_materials.data_encryption_key.data_key, - padding=self._wrapping_algorithm.padding, + plaintext=new_materials.data_encryption_key.data_key, padding=self._wrapping_algorithm.padding, ), tag=None, ) @@ -359,19 +371,15 @@ def on_encrypt(self, encryption_materials): encrypted_wrapped_key=encrypted_wrapped_key, ) except Exception: # pylint: disable=broad-except - error_message = "Raw RSA Keyring unable to encrypt data key" + error_message = "Raw RSA keyring unable to encrypt data key" _LOGGER.exception(error_message) - return encryption_materials + raise EncryptKeyError(error_message) # Update Keyring Trace - keyring_trace = KeyringTrace( - wrapping_key=encrypted_data_key.key_provider, flags={KeyringTraceFlag.ENCRYPTED_DATA_KEY} - ) + keyring_trace = KeyringTrace(wrapping_key=self._key_provider, flags={KeyringTraceFlag.ENCRYPTED_DATA_KEY}) # Add encrypted data key to encryption_materials - encryption_materials.add_encrypted_data_key(encrypted_data_key=encrypted_data_key, keyring_trace=keyring_trace) - - return encryption_materials + return new_materials.with_encrypted_data_key(encrypted_data_key=encrypted_data_key, keyring_trace=keyring_trace) def on_decrypt(self, decryption_materials, encrypted_data_keys): # type: (DecryptionMaterials, Iterable[EncryptedDataKey]) -> DecryptionMaterials @@ -380,18 +388,22 @@ def on_decrypt(self, decryption_materials, encrypted_data_keys): :param DecryptionMaterials decryption_materials: Decryption materials for keyring to modify. :param encrypted_data_keys: List of encrypted data keys. :type: List[EncryptedDataKey] - :returns: Optionally modified decryption materials. + :returns: Decryption materials that MAY include a plaintext data key :rtype: DecryptionMaterials """ + new_materials = decryption_materials + + if new_materials.data_encryption_key is not None: + return new_materials + if self._private_wrapping_key is None: - return decryption_materials + return new_materials # Decrypt data key for key in encrypted_data_keys: - if decryption_materials.data_encryption_key is not None: - return decryption_materials if key.key_provider != self._key_provider: continue + # Wrapped EncryptedDataKey to deserialized EncryptedData encrypted_wrapped_key = deserialize_wrapped_key( wrapping_algorithm=self._wrapping_algorithm, wrapping_key_id=self.key_name, wrapped_encrypted_key=key @@ -403,6 +415,8 @@ def on_decrypt(self, decryption_materials, encrypted_data_keys): except Exception: # pylint: disable=broad-except error_message = "Raw RSA Keyring unable to decrypt data key" _LOGGER.exception(error_message) + # The Raw RSA keyring MUST evaluate every encrypted data key + # until it either succeeds or runs out of encrypted data keys. continue # Create a keyring trace @@ -410,6 +424,9 @@ def on_decrypt(self, decryption_materials, encrypted_data_keys): # Update decryption materials data_encryption_key = RawDataKey(key_provider=self._key_provider, data_key=plaintext_data_key) - decryption_materials.add_data_encryption_key(data_encryption_key, keyring_trace) - return decryption_materials + return new_materials.with_data_encryption_key( + data_encryption_key=data_encryption_key, keyring_trace=keyring_trace + ) + + return new_materials diff --git a/src/aws_encryption_sdk/materials_managers/__init__.py b/src/aws_encryption_sdk/materials_managers/__init__.py index 28ea40c0d..8c8c33886 100644 --- a/src/aws_encryption_sdk/materials_managers/__init__.py +++ b/src/aws_encryption_sdk/materials_managers/__init__.py @@ -14,6 +14,8 @@ .. versionadded:: 1.3.0 """ +import copy + import attr import six from attr.validators import deep_iterable, deep_mapping, instance_of, optional @@ -117,6 +119,8 @@ def _validate_data_encryption_key(self, data_encryption_key, keyring_trace, requ # type: (Union[DataKey, RawDataKey], KeyringTrace, Iterable[KeyringTraceFlag]) -> None """Validate that the provided data encryption key and keyring trace match for each other and the materials. + .. versionadded:: 1.5.0 + :param RawDataKey data_encryption_key: Data encryption key :param KeyringTrace keyring_trace: Keyring trace corresponding to data_encryption_key :param required_flags: Iterable of required flags @@ -143,9 +147,11 @@ def _validate_data_encryption_key(self, data_encryption_key, keyring_trace, requ ) ) - def _add_data_encryption_key(self, data_encryption_key, keyring_trace, required_flags): - # type: (Union[DataKey, RawDataKey], KeyringTrace, Iterable[KeyringTraceFlag]) -> None - """Add a plaintext data encryption key. + def _with_data_encryption_key(self, data_encryption_key, keyring_trace, required_flags): + # type: (Union[DataKey, RawDataKey], KeyringTrace, Iterable[KeyringTraceFlag]) -> CryptographicMaterials + """Get new cryptographic materials that include this data encryption key. + + .. versionadded:: 1.5.0 :param RawDataKey data_encryption_key: Data encryption key :param KeyringTrace keyring_trace: Trace of actions that a keyring performed @@ -161,10 +167,15 @@ def _add_data_encryption_key(self, data_encryption_key, keyring_trace, required_ data_encryption_key=data_encryption_key, keyring_trace=keyring_trace, required_flags=required_flags ) + new_materials = copy.copy(self) + data_key = _data_key_to_raw_data_key(data_key=data_encryption_key) + new_materials._setattr( # simplify access to copies pylint: disable=protected-access + "data_encryption_key", data_key + ) + new_materials._keyring_trace.append(keyring_trace) # simplify access to copies pylint: disable=protected-access - super(CryptographicMaterials, self).__setattr__("data_encryption_key", data_key) - self._keyring_trace.append(keyring_trace) + return new_materials @property def keyring_trace(self): @@ -220,7 +231,8 @@ def __init__( if encryption_context is None: raise TypeError("encryption_context must not be None") - if data_encryption_key is None and encrypted_data_keys is not None: + if data_encryption_key is None and encrypted_data_keys: + # If data_encryption_key is not set, encrypted_data_keys MUST be either None or empty raise TypeError("encrypted_data_keys cannot be provided without data_encryption_key") if encrypted_data_keys is None: @@ -236,6 +248,18 @@ def __init__( self._setattr("_encrypted_data_keys", encrypted_data_keys) attr.validate(self) + def __copy__(self): + # type: () -> EncryptionMaterials + """Do a shallow copy of this instance.""" + return EncryptionMaterials( + algorithm=self.algorithm, + data_encryption_key=self.data_encryption_key, + encrypted_data_keys=copy.copy(self._encrypted_data_keys), + encryption_context=self.encryption_context.copy(), + signing_key=self.signing_key, + keyring_trace=copy.copy(self._keyring_trace), + ) + @property def encrypted_data_keys(self): # type: () -> Tuple[EncryptedDataKey] @@ -263,35 +287,37 @@ def is_complete(self): return True - def add_data_encryption_key(self, data_encryption_key, keyring_trace): - # type: (Union[DataKey, RawDataKey], KeyringTrace) -> None - """Add a plaintext data encryption key. + def with_data_encryption_key(self, data_encryption_key, keyring_trace): + # type: (Union[DataKey, RawDataKey], KeyringTrace) -> EncryptionMaterials + """Get new encryption materials that also include this data encryption key. .. versionadded:: 1.5.0 :param RawDataKey data_encryption_key: Data encryption key :param KeyringTrace keyring_trace: Trace of actions that a keyring performed while getting this data encryption key + :rtype: EncryptionMaterials :raises AttributeError: if data encryption key is already set :raises InvalidKeyringTraceError: if keyring trace does not match generate action :raises InvalidKeyringTraceError: if keyring trace does not match data key provider :raises InvalidDataKeyError: if data key length does not match algorithm suite """ - self._add_data_encryption_key( + return self._with_data_encryption_key( data_encryption_key=data_encryption_key, keyring_trace=keyring_trace, required_flags={KeyringTraceFlag.GENERATED_DATA_KEY}, ) - def add_encrypted_data_key(self, encrypted_data_key, keyring_trace): - # type: (EncryptedDataKey, KeyringTrace) -> None - """Add an encrypted data key with corresponding keyring trace. + def with_encrypted_data_key(self, encrypted_data_key, keyring_trace): + # type: (EncryptedDataKey, KeyringTrace) -> EncryptionMaterials + """Get new encryption materials that also include this encrypted data key with corresponding keyring trace. .. versionadded:: 1.5.0 :param EncryptedDataKey encrypted_data_key: Encrypted data key to add :param KeyringTrace keyring_trace: Trace of actions that a keyring performed while getting this encrypted data key + :rtype: EncryptionMaterials :raises AttributeError: if data encryption key is not set :raises InvalidKeyringTraceError: if keyring trace does not match generate action :raises InvalidKeyringTraceError: if keyring trace does not match data key encryptor @@ -302,19 +328,30 @@ def add_encrypted_data_key(self, encrypted_data_key, keyring_trace): if KeyringTraceFlag.ENCRYPTED_DATA_KEY not in keyring_trace.flags: raise InvalidKeyringTraceError("Keyring flags do not match action.") - if keyring_trace.wrapping_key != encrypted_data_key.key_provider: + if not all( + ( + keyring_trace.wrapping_key.provider_id == encrypted_data_key.key_provider.provider_id, + keyring_trace.wrapping_key.key_name == encrypted_data_key.key_provider.key_name, + ) + ): raise InvalidKeyringTraceError("Keyring trace does not match data key encryptor.") - self._encrypted_data_keys.append(encrypted_data_key) - self._keyring_trace.append(keyring_trace) + new_materials = copy.copy(self) - def add_signing_key(self, signing_key): - # type: (bytes) -> None - """Add a signing key. + new_materials._encrypted_data_keys.append( # simplify access to copies pylint: disable=protected-access + encrypted_data_key + ) + new_materials._keyring_trace.append(keyring_trace) # simplify access to copies pylint: disable=protected-access + return new_materials + + def with_signing_key(self, signing_key): + # type: (bytes) -> EncryptionMaterials + """Get new encryption materials that also include this signing key. .. versionadded:: 1.5.0 :param bytes signing_key: Signing key + :rtype: EncryptionMaterials :raises AttributeError: if signing key is already set :raises SignatureKeyError: if algorithm suite does not support signing keys """ @@ -324,10 +361,14 @@ def add_signing_key(self, signing_key): if self.algorithm.signing_algorithm_info is None: raise SignatureKeyError("Algorithm suite does not support signing keys.") + new_materials = copy.copy(self) + # Verify that the signing key matches the algorithm - Signer.from_key_bytes(algorithm=self.algorithm, key_bytes=signing_key) + Signer.from_key_bytes(algorithm=new_materials.algorithm, key_bytes=signing_key) - self._setattr("signing_key", signing_key) + new_materials._setattr("signing_key", signing_key) # simplify access to copies pylint: disable=protected-access + + return new_materials @attr.s(hash=False) @@ -401,6 +442,17 @@ def __init__( self._setattr("verification_key", verification_key) attr.validate(self) + def __copy__(self): + # type: () -> DecryptionMaterials + """Do a shallow copy of this instance.""" + return DecryptionMaterials( + algorithm=self.algorithm, + data_encryption_key=self.data_encryption_key, + encryption_context=copy.copy(self.encryption_context), + verification_key=self.verification_key, + keyring_trace=copy.copy(self._keyring_trace), + ) + @property def is_complete(self): # type: () -> bool @@ -425,15 +477,16 @@ def data_key(self): """Backwards-compatible shim for access to data key.""" return self.data_encryption_key - def add_data_encryption_key(self, data_encryption_key, keyring_trace): - # type: (Union[DataKey, RawDataKey], KeyringTrace) -> None - """Add a plaintext data encryption key. + def with_data_encryption_key(self, data_encryption_key, keyring_trace): + # type: (Union[DataKey, RawDataKey], KeyringTrace) -> DecryptionMaterials + """Get new decryption materials that also include this data encryption key. .. versionadded:: 1.5.0 :param RawDataKey data_encryption_key: Data encryption key :param KeyringTrace keyring_trace: Trace of actions that a keyring performed while getting this data encryption key + :rtype: DecryptionMaterials :raises AttributeError: if data encryption key is already set :raises InvalidKeyringTraceError: if keyring trace does not match decrypt action :raises InvalidKeyringTraceError: if keyring trace does not match data key provider @@ -442,19 +495,20 @@ def add_data_encryption_key(self, data_encryption_key, keyring_trace): if self.algorithm is None: raise AttributeError("Algorithm is not set") - self._add_data_encryption_key( + return self._with_data_encryption_key( data_encryption_key=data_encryption_key, keyring_trace=keyring_trace, required_flags={KeyringTraceFlag.DECRYPTED_DATA_KEY}, ) - def add_verification_key(self, verification_key): - # type: (bytes) -> None - """Add a verification key. + def with_verification_key(self, verification_key): + # type: (bytes) -> DecryptionMaterials + """Get new decryption materials that also include this verification key. .. versionadded:: 1.5.0 :param bytes verification_key: Verification key + :rtype: DecryptionMaterials """ if self.verification_key is not None: raise AttributeError("Verification key is already set.") @@ -462,7 +516,13 @@ def add_verification_key(self, verification_key): if self.algorithm.signing_algorithm_info is None: raise SignatureKeyError("Algorithm suite does not support signing keys.") + new_materials = copy.copy(self) + # Verify that the verification key matches the algorithm - Verifier.from_key_bytes(algorithm=self.algorithm, key_bytes=verification_key) + Verifier.from_key_bytes(algorithm=new_materials.algorithm, key_bytes=verification_key) - self._setattr("verification_key", verification_key) + new_materials._setattr( # simplify access to copies pylint: disable=protected-access + "verification_key", verification_key + ) + + return new_materials diff --git a/src/aws_encryption_sdk/structures.py b/src/aws_encryption_sdk/structures.py index 35eab24e6..ea0af94e9 100644 --- a/src/aws_encryption_sdk/structures.py +++ b/src/aws_encryption_sdk/structures.py @@ -15,7 +15,7 @@ import attr import six -from attr.validators import deep_iterable, deep_mapping, instance_of +from attr.validators import deep_iterable, deep_mapping, instance_of, optional from aws_encryption_sdk.identifiers import Algorithm, ContentType, KeyringTraceFlag, ObjectType, SerializationVersion from aws_encryption_sdk.internal.str_ops import to_bytes, to_str @@ -31,12 +31,39 @@ class MasterKeyInfo(object): """Contains information necessary to identify a Master Key. + .. notice:: + + The only keyring or master key that should need to set ``key_name`` is the Raw AES keyring/master key. + For all other keyrings and master keys, ``key_info`` and ``key_name`` should always be the same. + + + .. versionadded:: 1.5.0 + ``key_name`` + :param str provider_id: MasterKey provider_id value :param bytes key_info: MasterKey key_info value + :param bytes key_name: Key name if different than key_info (optional) """ provider_id = attr.ib(hash=True, validator=instance_of((six.string_types, bytes)), converter=to_str) key_info = attr.ib(hash=True, validator=instance_of((six.string_types, bytes)), converter=to_bytes) + key_name = attr.ib( + hash=True, default=None, validator=optional(instance_of((six.string_types, bytes))), converter=to_bytes + ) + + def __attrs_post_init__(self): + """Set ``key_name`` if not already set.""" + if self.key_name is None: + self.key_name = self.key_info + + @property + def key_namespace(self): + """Access the key namespace value (previously, provider ID). + + .. versionadded:: 1.5.0 + + """ + return self.provider_id @attr.s(hash=True) diff --git a/test/functional/keyrings/aws_kms/test_aws_kms.py b/test/functional/keyrings/aws_kms/test_aws_kms.py index 7a5b50d51..c84174f69 100644 --- a/test/functional/keyrings/aws_kms/test_aws_kms.py +++ b/test/functional/keyrings/aws_kms/test_aws_kms.py @@ -79,6 +79,7 @@ def test_aws_kms_single_cmk_keyring_on_encrypt_existing_data_key(fake_generator) result_materials = keyring.on_encrypt(initial_materials) + assert result_materials is not initial_materials assert result_materials.data_encryption_key is not None assert len(result_materials.encrypted_data_keys) == 1 @@ -149,6 +150,7 @@ def test_aws_kms_single_cmk_keyring_on_decrypt_single_cmk(fake_generator): decryption_materials=initial_decryption_materials, encrypted_data_keys=encryption_materials.encrypted_data_keys ) + assert result_materials is not initial_decryption_materials assert result_materials.data_encryption_key is not None generator_flags = _matching_flags( @@ -243,6 +245,7 @@ def test_aws_kms_discovery_keyring_on_encrypt(): result_materials = keyring.on_encrypt(initial_materials) + assert result_materials is initial_materials assert len(result_materials.encrypted_data_keys) == 0 @@ -268,6 +271,7 @@ def test_aws_kms_discovery_keyring_on_decrypt(encryption_materials_for_discovery decryption_materials=initial_decryption_materials, encrypted_data_keys=encryption_materials.encrypted_data_keys ) + assert result_materials is not initial_decryption_materials assert result_materials.data_encryption_key is not None generator_flags = _matching_flags( diff --git a/test/unit/internal/formatting/test_serialize.py b/test/unit/internal/formatting/test_serialize.py index 8dbe9bd05..7a4063472 100644 --- a/test/unit/internal/formatting/test_serialize.py +++ b/test/unit/internal/formatting/test_serialize.py @@ -325,7 +325,9 @@ def test_serialize_wrapped_key_symmetric(self): ) assert test == EncryptedDataKey( key_provider=MasterKeyInfo( - provider_id=VALUES["provider_id"], key_info=VALUES["wrapped_keys"]["serialized"]["key_info"] + provider_id=self.mock_key_provider.provider_id, + key_info=VALUES["wrapped_keys"]["serialized"]["key_info"], + key_name=VALUES["wrapped_keys"]["raw"]["key_info"], ), encrypted_data_key=VALUES["wrapped_keys"]["serialized"]["key_ciphertext"], ) diff --git a/test/unit/keyrings/raw/test_raw_aes.py b/test/unit/keyrings/raw/test_raw_aes.py index 33a882918..1f4322e09 100644 --- a/test/unit/keyrings/raw/test_raw_aes.py +++ b/test/unit/keyrings/raw/test_raw_aes.py @@ -16,10 +16,10 @@ import mock import pytest -from pytest_mock import mocker # noqa pylint: disable=unused-import import aws_encryption_sdk.key_providers.raw import aws_encryption_sdk.keyrings.raw +from aws_encryption_sdk.exceptions import EncryptKeyError from aws_encryption_sdk.identifiers import Algorithm, KeyringTraceFlag, WrappingAlgorithm from aws_encryption_sdk.internal.crypto.wrapping_keys import WrappingKey from aws_encryption_sdk.keyrings.base import Keyring @@ -67,6 +67,12 @@ def patch_decrypt_on_wrapping_key(mocker): return WrappingKey.decrypt +@pytest.fixture +def patch_encrypt_on_wrapping_key(mocker): + mocker.patch.object(WrappingKey, "encrypt") + return WrappingKey.encrypt + + @pytest.fixture def patch_os_urandom(mocker): mocker.patch.object(os, "urandom") @@ -127,10 +133,18 @@ def test_keyring_trace_on_encrypt_when_data_encryption_key_given(raw_aes_keyring test = test_raw_aes_keyring.on_encrypt(encryption_materials=get_encryption_materials_with_data_encryption_key()) - for keyring_trace in test.keyring_trace: - if keyring_trace.wrapping_key.key_info == _KEY_ID: - # Check keyring trace does not contain KeyringTraceFlag.GENERATED_DATA_KEY - assert KeyringTraceFlag.GENERATED_DATA_KEY not in keyring_trace.flags + trace_entries = [entry for entry in test.keyring_trace if entry.wrapping_key == raw_aes_keyring._key_provider] + assert len(trace_entries) == 1 + + generate_traces = [entry for entry in trace_entries if entry.flags == {KeyringTraceFlag.GENERATED_DATA_KEY}] + assert len(generate_traces) == 0 + + encrypt_traces = [ + entry + for entry in trace_entries + if entry.flags == {KeyringTraceFlag.ENCRYPTED_DATA_KEY, KeyringTraceFlag.SIGNED_ENCRYPTION_CONTEXT} + ] + assert len(encrypt_traces) == 1 def test_on_encrypt_when_data_encryption_key_not_given(raw_aes_keyring): @@ -146,24 +160,29 @@ def test_on_encrypt_when_data_encryption_key_not_given(raw_aes_keyring): # Check if data key is generated assert test.data_encryption_key is not None - generated_flag_count = 0 - encrypted_flag_count = 0 + trace_entries = [entry for entry in test.keyring_trace if entry.wrapping_key == raw_aes_keyring._key_provider] + assert len(trace_entries) == 2 - for keyring_trace in test.keyring_trace: - if ( - keyring_trace.wrapping_key.key_info == _KEY_ID - and KeyringTraceFlag.GENERATED_DATA_KEY in keyring_trace.flags - ): - # Check keyring trace contains KeyringTraceFlag.GENERATED_DATA_KEY - generated_flag_count += 1 - if KeyringTraceFlag.ENCRYPTED_DATA_KEY in keyring_trace.flags: - encrypted_flag_count += 1 + generate_traces = [entry for entry in trace_entries if entry.flags == {KeyringTraceFlag.GENERATED_DATA_KEY}] + assert len(generate_traces) == 1 - assert generated_flag_count == 1 + encrypt_traces = [ + entry + for entry in trace_entries + if entry.flags == {KeyringTraceFlag.ENCRYPTED_DATA_KEY, KeyringTraceFlag.SIGNED_ENCRYPTION_CONTEXT} + ] + assert len(encrypt_traces) == 1 assert len(test.encrypted_data_keys) == original_number_of_encrypted_data_keys + 1 - assert encrypted_flag_count == 1 + +def test_on_encrypt_cannot_encrypt(patch_encrypt_on_wrapping_key, raw_aes_keyring): + patch_encrypt_on_wrapping_key.side_effect = Exception("ENCRYPT FAIL") + + with pytest.raises(EncryptKeyError) as excinfo: + raw_aes_keyring.on_encrypt(get_encryption_materials_without_data_encryption_key()) + + excinfo.match("Raw AES keyring unable to encrypt data key") @pytest.mark.parametrize( @@ -179,16 +198,15 @@ def test_on_decrypt_when_data_key_given(raw_aes_keyring, decryption_materials, e assert not patch_decrypt_on_wrapping_key.called -def test_keyring_trace_on_decrypt_when_data_key_given(raw_aes_keyring): +def test_on_decrypt_keyring_trace_when_data_key_given(raw_aes_keyring): test_raw_aes_keyring = raw_aes_keyring test = test_raw_aes_keyring.on_decrypt( decryption_materials=get_decryption_materials_with_data_encryption_key(), encrypted_data_keys=[_ENCRYPTED_DATA_KEY_AES], ) - for keyring_trace in test.keyring_trace: - if keyring_trace.wrapping_key.key_info == _KEY_ID: - # Check keyring trace does not contain KeyringTraceFlag.DECRYPTED_DATA_KEY - assert KeyringTraceFlag.DECRYPTED_DATA_KEY not in keyring_trace.flags + + trace_entries = [entry for entry in test.keyring_trace if entry.wrapping_key == raw_aes_keyring._key_provider] + assert len(trace_entries) == 0 @pytest.mark.parametrize( @@ -206,9 +224,8 @@ def test_on_decrypt_when_data_key_and_edk_not_provided( test = test_raw_aes_keyring.on_decrypt(decryption_materials=decryption_materials, encrypted_data_keys=edk) assert not patch_decrypt_on_wrapping_key.called - for keyring_trace in test.keyring_trace: - if keyring_trace.wrapping_key.key_info == _KEY_ID: - assert KeyringTraceFlag.DECRYPTED_DATA_KEY not in keyring_trace.flags + trace_entries = [entry for entry in test.keyring_trace if entry.wrapping_key == raw_aes_keyring._key_provider] + assert len(trace_entries) == 0 assert test.data_encryption_key is None @@ -225,23 +242,38 @@ def test_on_decrypt_when_data_key_not_provided_and_edk_provided(raw_aes_keyring, ) -def test_keyring_trace_when_data_key_not_provided_and_edk_provided(raw_aes_keyring): +def test_on_decrypt_keyring_trace_when_data_key_not_provided_and_edk_provided(raw_aes_keyring): test_raw_aes_keyring = raw_aes_keyring test = test_raw_aes_keyring.on_decrypt( decryption_materials=get_decryption_materials_without_data_encryption_key(), encrypted_data_keys=[_ENCRYPTED_DATA_KEY_AES], ) - decrypted_flag_count = 0 - for keyring_trace in test.keyring_trace: - if KeyringTraceFlag.DECRYPTED_DATA_KEY in keyring_trace.flags: - decrypted_flag_count += 1 + trace_entries = [entry for entry in test.keyring_trace if entry.wrapping_key == raw_aes_keyring._key_provider] + assert len(trace_entries) == 1 - assert decrypted_flag_count == 1 + decrypt_traces = [ + entry + for entry in trace_entries + if entry.flags == {KeyringTraceFlag.DECRYPTED_DATA_KEY, KeyringTraceFlag.VERIFIED_ENCRYPTION_CONTEXT} + ] + assert len(decrypt_traces) == 1 -def test_error_when_data_key_not_generated(patch_os_urandom): +def test_on_decrypt_continues_through_edks_on_failure(raw_aes_keyring, patch_decrypt_on_wrapping_key): + patch_decrypt_on_wrapping_key.side_effect = (Exception("DECRYPT FAIL"), _DATA_KEY) + + test = raw_aes_keyring.on_decrypt( + decryption_materials=get_decryption_materials_without_data_encryption_key(), + encrypted_data_keys=(_ENCRYPTED_DATA_KEY_AES, _ENCRYPTED_DATA_KEY_AES), + ) + + assert test.data_encryption_key is not None + assert patch_decrypt_on_wrapping_key.call_count == 2 + + +def test_generate_data_key_error_when_data_key_not_generated(patch_os_urandom): patch_os_urandom.side_effect = NotImplementedError with pytest.raises(GenerateKeyError) as exc_info: _generate_data_key( @@ -266,17 +298,20 @@ def test_generate_data_key_keyring_trace(): encryption_context=_ENCRYPTION_CONTEXT, signing_key=_SIGNING_KEY, ) - _generate_data_key( - encryption_materials=encryption_materials_without_data_key, - key_provider=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=_KEY_ID), + key_provider_info = MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=_KEY_ID) + new_materials = _generate_data_key( + encryption_materials=encryption_materials_without_data_key, key_provider=key_provider_info, ) - assert encryption_materials_without_data_key.data_encryption_key.key_provider.provider_id == _PROVIDER_ID - assert encryption_materials_without_data_key.data_encryption_key.key_provider.key_info == _KEY_ID + assert new_materials is not encryption_materials_without_data_key + assert encryption_materials_without_data_key.data_encryption_key is None + assert not encryption_materials_without_data_key.keyring_trace + + assert new_materials.data_encryption_key is not None + assert new_materials.data_encryption_key.key_provider == key_provider_info - generate_flag_count = 0 + trace_entries = [entry for entry in new_materials.keyring_trace if entry.wrapping_key == key_provider_info] + assert len(trace_entries) == 1 - for keyring_trace in encryption_materials_without_data_key.keyring_trace: - if KeyringTraceFlag.GENERATED_DATA_KEY in keyring_trace.flags: - generate_flag_count += 1 - assert generate_flag_count == 1 + generate_traces = [entry for entry in trace_entries if entry.flags == {KeyringTraceFlag.GENERATED_DATA_KEY}] + assert len(generate_traces) == 1 diff --git a/test/unit/keyrings/raw/test_raw_rsa.py b/test/unit/keyrings/raw/test_raw_rsa.py index d9aa7b266..4b0a8d506 100644 --- a/test/unit/keyrings/raw/test_raw_rsa.py +++ b/test/unit/keyrings/raw/test_raw_rsa.py @@ -14,10 +14,10 @@ import pytest from cryptography.hazmat.primitives.asymmetric import rsa -from pytest_mock import mocker # noqa pylint: disable=unused-import import aws_encryption_sdk.key_providers.raw import aws_encryption_sdk.keyrings.raw +from aws_encryption_sdk.exceptions import EncryptKeyError from aws_encryption_sdk.identifiers import KeyringTraceFlag, WrappingAlgorithm from aws_encryption_sdk.internal.crypto.wrapping_keys import WrappingKey from aws_encryption_sdk.keyrings.base import Keyring @@ -26,6 +26,7 @@ from ...unit_test_utils import ( _BACKEND, _DATA_KEY, + _ENCRYPTED_DATA_KEY_AES, _ENCRYPTED_DATA_KEY_RSA, _ENCRYPTION_CONTEXT, _KEY_ID, @@ -123,15 +124,28 @@ def test_on_encrypt_when_data_encryption_key_given(raw_rsa_keyring, patch_genera assert not patch_generate_data_key.called -def test_keyring_trace_on_encrypt_when_data_encryption_key_given(raw_rsa_keyring): - test_raw_rsa_keyring = raw_rsa_keyring +def test_on_encrypt_no_public_key(raw_rsa_keyring): + raw_rsa_keyring._public_wrapping_key = None + + with pytest.raises(EncryptKeyError) as excinfo: + raw_rsa_keyring.on_encrypt(encryption_materials=get_encryption_materials_without_data_encryption_key()) + + excinfo.match("Raw RSA keyring unable to encrypt data key: no public key available") + - test = test_raw_rsa_keyring.on_encrypt(encryption_materials=get_encryption_materials_with_data_encryption_key()) +def test_on_encrypt_keyring_trace_when_data_encryption_key_given(raw_rsa_keyring): + materials = get_encryption_materials_with_data_encryption_key() + test = raw_rsa_keyring.on_encrypt(encryption_materials=materials) + assert test is not materials - for keyring_trace in test.keyring_trace: - if keyring_trace.wrapping_key.key_info == _KEY_ID: - # Check keyring trace does not contain KeyringTraceFlag.GENERATED_DATA_KEY - assert KeyringTraceFlag.GENERATED_DATA_KEY not in keyring_trace.flags + trace_entries = [entry for entry in test.keyring_trace if entry.wrapping_key == raw_rsa_keyring._key_provider] + assert len(trace_entries) == 1 + + encrypt_traces = [entry for entry in trace_entries if entry.flags == {KeyringTraceFlag.ENCRYPTED_DATA_KEY}] + assert len(encrypt_traces) == 1 + + generate_traces = [entry for entry in trace_entries if entry.flags == {KeyringTraceFlag.GENERATED_DATA_KEY}] + assert len(generate_traces) == 0 def test_on_encrypt_when_data_encryption_key_not_given(raw_rsa_keyring): @@ -143,27 +157,28 @@ def test_on_encrypt_when_data_encryption_key_not_given(raw_rsa_keyring): test = test_raw_rsa_keyring.on_encrypt(encryption_materials=get_encryption_materials_without_data_encryption_key()) - # Check if data key is generated - assert test.data_encryption_key is not None + trace_entries = [entry for entry in test.keyring_trace if entry.wrapping_key == raw_rsa_keyring._key_provider] + assert len(trace_entries) == 2 - generated_flag_count = 0 - encrypted_flag_count = 0 + encrypt_traces = [entry for entry in trace_entries if entry.flags == {KeyringTraceFlag.ENCRYPTED_DATA_KEY}] + assert len(encrypt_traces) == 1 - for keyring_trace in test.keyring_trace: - if ( - keyring_trace.wrapping_key.key_info == _KEY_ID - and KeyringTraceFlag.GENERATED_DATA_KEY in keyring_trace.flags - ): - # Check keyring trace contains KeyringTraceFlag.GENERATED_DATA_KEY - generated_flag_count += 1 - if KeyringTraceFlag.ENCRYPTED_DATA_KEY in keyring_trace.flags: - encrypted_flag_count += 1 + generate_traces = [entry for entry in trace_entries if entry.flags == {KeyringTraceFlag.GENERATED_DATA_KEY}] + assert len(generate_traces) == 1 - assert generated_flag_count == 1 + assert test.data_encryption_key.data_key is not None assert len(test.encrypted_data_keys) == original_number_of_encrypted_data_keys + 1 - assert encrypted_flag_count == 1 + +def test_on_encrypt_cannot_encrypt(raw_rsa_keyring, mocker): + encrypt_patch = mocker.patch.object(raw_rsa_keyring._public_wrapping_key, "encrypt") + encrypt_patch.side_effect = Exception("ENCRYPT FAIL") + + with pytest.raises(EncryptKeyError) as excinfo: + raw_rsa_keyring.on_encrypt(encryption_materials=get_encryption_materials_without_data_encryption_key()) + + excinfo.match("Raw RSA keyring unable to encrypt data key") def test_on_decrypt_when_data_key_given(raw_rsa_keyring, patch_decrypt_on_wrapping_key): @@ -175,16 +190,23 @@ def test_on_decrypt_when_data_key_given(raw_rsa_keyring, patch_decrypt_on_wrappi assert not patch_decrypt_on_wrapping_key.called -def test_keyring_trace_on_decrypt_when_data_key_given(raw_rsa_keyring): +def test_on_decrypt_no_private_key(raw_rsa_keyring): + raw_rsa_keyring._private_wrapping_key = None + + materials = get_decryption_materials_without_data_encryption_key() + test = raw_rsa_keyring.on_decrypt(decryption_materials=materials, encrypted_data_keys=[_ENCRYPTED_DATA_KEY_RSA],) + + assert test is materials + + +def test_on_decrypt_keyring_trace_when_data_key_given(raw_rsa_keyring): test_raw_rsa_keyring = raw_rsa_keyring test = test_raw_rsa_keyring.on_decrypt( decryption_materials=get_decryption_materials_with_data_encryption_key(), encrypted_data_keys=[_ENCRYPTED_DATA_KEY_RSA], ) - for keyring_trace in test.keyring_trace: - if keyring_trace.wrapping_key.key_info == _KEY_ID: - # Check keyring trace does not contain KeyringTraceFlag.DECRYPTED_DATA_KEY - assert KeyringTraceFlag.DECRYPTED_DATA_KEY not in keyring_trace.flags + trace_entries = [entry for entry in test.keyring_trace if entry.wrapping_key == raw_rsa_keyring._key_provider] + assert len(trace_entries) == 0 def test_on_decrypt_when_data_key_and_edk_not_provided(raw_rsa_keyring, patch_decrypt_on_wrapping_key): @@ -195,8 +217,21 @@ def test_on_decrypt_when_data_key_and_edk_not_provided(raw_rsa_keyring, patch_de ) assert not patch_decrypt_on_wrapping_key.called - for keyring_trace in test.keyring_trace: - assert KeyringTraceFlag.DECRYPTED_DATA_KEY not in keyring_trace.flags + trace_entries = [entry for entry in test.keyring_trace if entry.wrapping_key == raw_rsa_keyring._key_provider] + assert len(trace_entries) == 0 + + assert test.data_encryption_key is None + + +def test_on_decrypt_when_data_key_not_provided_and_no_know_edks(raw_rsa_keyring, mocker): + patched_wrapping_key_decrypt = mocker.patch.object(raw_rsa_keyring._private_wrapping_key, "decrypt") + + test = raw_rsa_keyring.on_decrypt( + decryption_materials=get_decryption_materials_without_data_encryption_key(), + encrypted_data_keys=[_ENCRYPTED_DATA_KEY_AES], + ) + + assert not patched_wrapping_key_decrypt.called assert test.data_encryption_key is None @@ -210,9 +245,8 @@ def test_on_decrypt_when_data_key_not_provided_and_edk_not_in_keyring(raw_rsa_ke ) assert not patch_decrypt_on_wrapping_key.called - for keyring_trace in test.keyring_trace: - if keyring_trace.wrapping_key.key_info == _KEY_ID: - assert KeyringTraceFlag.DECRYPTED_DATA_KEY not in keyring_trace.flags + trace_entries = [entry for entry in test.keyring_trace if entry.wrapping_key == raw_rsa_keyring._key_provider] + assert not trace_entries assert test.data_encryption_key is None @@ -230,7 +264,7 @@ def test_on_decrypt_when_data_key_not_provided_and_edk_provided(raw_rsa_keyring, ) -def test_keyring_trace_when_data_key_not_provided_and_edk_provided(raw_rsa_keyring): +def test_on_decrypt_keyring_trace_when_data_key_not_provided_and_edk_provided(raw_rsa_keyring): test_raw_rsa_keyring = raw_rsa_keyring test = test_raw_rsa_keyring.on_decrypt( @@ -239,11 +273,31 @@ def test_keyring_trace_when_data_key_not_provided_and_edk_provided(raw_rsa_keyri encryption_materials=get_encryption_materials_without_data_encryption_key() ).encrypted_data_keys, ) - decrypted_flag_count = 0 - for keyring_trace in test.keyring_trace: - if KeyringTraceFlag.DECRYPTED_DATA_KEY in keyring_trace.flags: - decrypted_flag_count += 1 + trace_entries = [entry for entry in test.keyring_trace if entry.wrapping_key == raw_rsa_keyring._key_provider] + assert len(trace_entries) == 1 + + decrypt_traces = [entry for entry in trace_entries if entry.flags == {KeyringTraceFlag.DECRYPTED_DATA_KEY}] + assert len(decrypt_traces) == 1 - assert decrypted_flag_count == 1 assert test.data_encryption_key is not None + + +def test_on_decrypt_continues_through_edks_on_failure(raw_rsa_keyring, mocker): + patched_wrapping_key_decrypt = mocker.patch.object(raw_rsa_keyring._private_wrapping_key, "decrypt") + patched_wrapping_key_decrypt.side_effect = (Exception("DECRYPT FAIL"), _DATA_KEY) + + test = raw_rsa_keyring.on_decrypt( + decryption_materials=get_decryption_materials_without_data_encryption_key(), + encrypted_data_keys=(_ENCRYPTED_DATA_KEY_RSA, _ENCRYPTED_DATA_KEY_RSA), + ) + + assert patched_wrapping_key_decrypt.call_count == 2 + + trace_entries = [entry for entry in test.keyring_trace if entry.wrapping_key == raw_rsa_keyring._key_provider] + assert len(trace_entries) == 1 + + decrypt_traces = [entry for entry in trace_entries if entry.flags == {KeyringTraceFlag.DECRYPTED_DATA_KEY}] + assert len(decrypt_traces) == 1 + + assert test.data_encryption_key.data_key == _DATA_KEY diff --git a/test/unit/keyrings/test_multi.py b/test/unit/keyrings/test_multi.py index c0bdc78d9..747ef5c37 100644 --- a/test/unit/keyrings/test_multi.py +++ b/test/unit/keyrings/test_multi.py @@ -187,7 +187,11 @@ def test_number_of_encrypted_data_keys_with_generator_and_children(): def test_on_encrypt_when_data_encryption_key_given(mock_generator, mock_child_1, mock_child_2): test_multi_keyring = MultiKeyring(generator=mock_generator, children=[mock_child_1, mock_child_2]) - test_multi_keyring.on_encrypt(encryption_materials=get_encryption_materials_with_data_key()) + initial_materials = get_encryption_materials_with_data_key() + new_materials = test_multi_keyring.on_encrypt(encryption_materials=initial_materials) + + assert new_materials is not initial_materials + for keyring in test_multi_keyring._decryption_keyrings: keyring.on_encrypt.assert_called_once() @@ -208,7 +212,11 @@ def test_on_encrypt_edk_length_when_keyring_generates_but_does_not_encrypt_encry def test_on_decrypt_when_data_encryption_key_given(mock_generator, mock_child_1, mock_child_2): test_multi_keyring = MultiKeyring(generator=mock_generator, children=[mock_child_1, mock_child_2]) - test_multi_keyring.on_decrypt(decryption_materials=get_decryption_materials_with_data_key(), encrypted_data_keys=[]) + initial_materials = get_decryption_materials_with_data_key() + new_materials = test_multi_keyring.on_decrypt(decryption_materials=initial_materials, encrypted_data_keys=[]) + + assert new_materials is initial_materials + for keyring in test_multi_keyring._decryption_keyrings: assert not keyring.on_decrypt.called @@ -238,9 +246,10 @@ def test_no_keyring_called_after_data_encryption_key_added_when_data_encryption_ ) test_multi_keyring = MultiKeyring(generator=mock_generator, children=[mock_child_3, mock_child_1, mock_child_2]) - test_multi_keyring.on_decrypt( - decryption_materials=get_decryption_materials_without_data_key(), encrypted_data_keys=[] - ) + initial_materials = get_decryption_materials_without_data_key() + new_materials = test_multi_keyring.on_decrypt(decryption_materials=initial_materials, encrypted_data_keys=[]) + + assert new_materials is not initial_materials assert mock_generator.on_decrypt.called assert mock_child_3.on_decrypt.called assert not mock_child_1.called diff --git a/test/unit/materials_managers/test_material_managers.py b/test/unit/materials_managers/test_material_managers.py index 499f7ba0d..62314298e 100644 --- a/test/unit/materials_managers/test_material_managers.py +++ b/test/unit/materials_managers/test_material_managers.py @@ -16,7 +16,6 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import ec from mock import MagicMock -from pytest_mock import mocker # noqa pylint: disable=unused-import from aws_encryption_sdk.exceptions import InvalidDataKeyError, InvalidKeyringTraceError, SignatureKeyError from aws_encryption_sdk.identifiers import AlgorithmSuite, KeyringTraceFlag @@ -104,7 +103,6 @@ def _copy_and_update_kwargs(class_name, mod_kwargs): (EncryptionMaterials, dict(algorithm=None)), (EncryptionMaterials, dict(encryption_context=None)), (EncryptionMaterials, dict(signing_key=u"not bytes or None")), - (EncryptionMaterials, dict(data_encryption_key=_REMOVE)), (DecryptionMaterialsRequest, dict(algorithm=None)), (DecryptionMaterialsRequest, dict(encrypted_data_keys=None)), (DecryptionMaterialsRequest, dict(encryption_context=None)), @@ -123,6 +121,8 @@ def test_attributes_fails(attr_class, invalid_kwargs): ( (CryptographicMaterials, {}), (EncryptionMaterials, {}), + (EncryptionMaterials, dict(data_encryption_key=_REMOVE, encrypted_data_keys=[])), + (EncryptionMaterials, dict(data_encryption_key=_REMOVE, encrypted_data_keys=_REMOVE)), (DecryptionMaterials, {}), (DecryptionMaterials, dict(data_key=_REMOVE, data_encryption_key=_REMOVE)), (DecryptionMaterials, dict(data_key=_REMOVE, data_encryption_key=_RAW_DATA_KEY)), @@ -248,18 +248,19 @@ def test_empty_encrypted_data_keys(): (DecryptionMaterials, KeyringTraceFlag.DECRYPTED_DATA_KEY), ), ) -def test_add_data_encryption_key_success(material_class, flag): +def test_with_data_encryption_key_success(material_class, flag): kwargs = _copy_and_update_kwargs( material_class.__name__, dict(data_encryption_key=_REMOVE, data_key=_REMOVE, encrypted_data_keys=_REMOVE) ) materials = material_class(**kwargs) - materials.add_data_encryption_key( + new_materials = materials.with_data_encryption_key( data_encryption_key=RawDataKey( key_provider=MasterKeyInfo(provider_id="a", key_info=b"b"), data_key=b"1" * ALGORITHM.kdf_input_len ), keyring_trace=KeyringTrace(wrapping_key=MasterKeyInfo(provider_id="a", key_info=b"b"), flags={flag}), ) + assert new_materials is not materials def _add_data_encryption_key_test_cases(): @@ -313,28 +314,29 @@ def _add_data_encryption_key_test_cases(): "material_class, mod_kwargs, data_encryption_key, keyring_trace, exception_type, exception_message", _add_data_encryption_key_test_cases(), ) -def test_add_data_encryption_key_fail( +def test_with_data_encryption_key_fail( material_class, mod_kwargs, data_encryption_key, keyring_trace, exception_type, exception_message ): kwargs = _copy_and_update_kwargs(material_class.__name__, mod_kwargs) materials = material_class(**kwargs) with pytest.raises(exception_type) as excinfo: - materials.add_data_encryption_key(data_encryption_key=data_encryption_key, keyring_trace=keyring_trace) + materials.with_data_encryption_key(data_encryption_key=data_encryption_key, keyring_trace=keyring_trace) excinfo.match(exception_message) -def test_add_encrypted_data_key_success(): +def test_with_encrypted_data_key_success(): kwargs = _copy_and_update_kwargs("EncryptionMaterials", {}) materials = EncryptionMaterials(**kwargs) - materials.add_encrypted_data_key( + new_materials = materials.with_encrypted_data_key( _ENCRYPTED_DATA_KEY, keyring_trace=KeyringTrace( wrapping_key=_ENCRYPTED_DATA_KEY.key_provider, flags={KeyringTraceFlag.ENCRYPTED_DATA_KEY} ), ) + assert new_materials is not materials @pytest.mark.parametrize( @@ -366,21 +368,22 @@ def test_add_encrypted_data_key_success(): ), ), ) -def test_add_encrypted_data_key_fail(mod_kwargs, encrypted_data_key, keyring_trace, exception_type, exception_message): +def test_with_encrypted_data_key_fail(mod_kwargs, encrypted_data_key, keyring_trace, exception_type, exception_message): kwargs = _copy_and_update_kwargs("EncryptionMaterials", mod_kwargs) materials = EncryptionMaterials(**kwargs) with pytest.raises(exception_type) as excinfo: - materials.add_encrypted_data_key(encrypted_data_key=encrypted_data_key, keyring_trace=keyring_trace) + materials.with_encrypted_data_key(encrypted_data_key=encrypted_data_key, keyring_trace=keyring_trace) excinfo.match(exception_message) -def test_add_signing_key_success(): +def test_with_signing_key_success(): kwargs = _copy_and_update_kwargs("EncryptionMaterials", dict(signing_key=_REMOVE)) materials = EncryptionMaterials(**kwargs) - materials.add_signing_key(signing_key=_SIGNING_KEY.key_bytes()) + new_materials = materials.with_signing_key(signing_key=_SIGNING_KEY.key_bytes()) + assert new_materials is not materials @pytest.mark.parametrize( @@ -395,21 +398,22 @@ def test_add_signing_key_success(): ), ), ) -def test_add_signing_key_fail(mod_kwargs, signing_key, exception_type, exception_message): +def test_with_signing_key_fail(mod_kwargs, signing_key, exception_type, exception_message): kwargs = _copy_and_update_kwargs("EncryptionMaterials", mod_kwargs) materials = EncryptionMaterials(**kwargs) with pytest.raises(exception_type) as excinfo: - materials.add_signing_key(signing_key=signing_key) + materials.with_signing_key(signing_key=signing_key) excinfo.match(exception_message) -def test_add_verification_key_success(): +def test_with_verification_key_success(): kwargs = _copy_and_update_kwargs("DecryptionMaterials", dict(verification_key=_REMOVE)) materials = DecryptionMaterials(**kwargs) - materials.add_verification_key(verification_key=_VERIFICATION_KEY.key_bytes()) + new_materials = materials.with_verification_key(verification_key=_VERIFICATION_KEY.key_bytes()) + assert new_materials is not materials @pytest.mark.parametrize( @@ -424,12 +428,12 @@ def test_add_verification_key_success(): ), ), ) -def test_add_verification_key_fail(mod_kwargs, verification_key, exception_type, exception_message): +def test_with_verification_key_fail(mod_kwargs, verification_key, exception_type, exception_message): kwargs = _copy_and_update_kwargs("DecryptionMaterials", mod_kwargs) materials = DecryptionMaterials(**kwargs) with pytest.raises(exception_type) as excinfo: - materials.add_verification_key(verification_key=verification_key) + materials.with_verification_key(verification_key=verification_key) excinfo.match(exception_message) @@ -457,7 +461,9 @@ def test_decryption_materials_is_not_complete(mod_kwargs): def test_encryption_materials_is_complete(): - materials = EncryptionMaterials(**_copy_and_update_kwargs("EncryptionMaterials", {})) + materials = EncryptionMaterials( + **_copy_and_update_kwargs("EncryptionMaterials", dict(encrypted_data_keys=[_ENCRYPTED_DATA_KEY])) + ) assert materials.is_complete @@ -466,6 +472,7 @@ def test_encryption_materials_is_complete(): "mod_kwargs", ( dict(data_encryption_key=_REMOVE, encrypted_data_keys=_REMOVE), + dict(encrypted_data_keys=[]), dict(encrypted_data_keys=_REMOVE), dict(signing_key=_REMOVE), ), diff --git a/test/unit/unit_test_utils.py b/test/unit/unit_test_utils.py index ddde3e975..9063badc6 100644 --- a/test/unit/unit_test_utils.py +++ b/test/unit/unit_test_utils.py @@ -33,6 +33,7 @@ _ENCRYPTION_CONTEXT = {"encryption": "context", "values": "here"} _PROVIDER_ID = "Random Raw Keys" +_EXISTING_KEY_ID = b"pre-seeded key id" _KEY_ID = b"5325b043-5843-4629-869c-64794af77ada" _WRAPPING_KEY = b"\xeby-\x80A6\x15rA8\x83#,\xe4\xab\xac`\xaf\x99Z\xc1\xce\xdb\xb6\x0f\xb7\x805\xb2\x14J3" _SIGNING_KEY = b"aws-crypto-public-key" @@ -98,7 +99,7 @@ def on_encrypt(self, encryption_materials): data_encryption_key = RawDataKey( key_provider=key_provider, data_key=os.urandom(encryption_materials.algorithm.kdf_input_len) ) - encryption_materials.add_data_encryption_key( + encryption_materials = encryption_materials.with_data_encryption_key( data_encryption_key=data_encryption_key, keyring_trace=KeyringTrace(wrapping_key=key_provider, flags={KeyringTraceFlag.GENERATED_DATA_KEY}), ) @@ -113,14 +114,14 @@ def get_encryption_materials_with_data_key(): return EncryptionMaterials( algorithm=AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, data_encryption_key=RawDataKey( - key_provider=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=_KEY_ID), + key_provider=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=_EXISTING_KEY_ID), data_key=b'*!\xa1"^-(\xf3\x105\x05i@B\xc2\xa2\xb7\xdd\xd5\xd5\xa9\xddm\xfae\xa8\\$\xf9d\x1e(', ), encryption_context=_ENCRYPTION_CONTEXT, signing_key=_SIGNING_KEY, keyring_trace=[ KeyringTrace( - wrapping_key=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=_KEY_ID), + wrapping_key=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=_EXISTING_KEY_ID), flags={KeyringTraceFlag.GENERATED_DATA_KEY}, ) ], @@ -131,14 +132,14 @@ def get_encryption_materials_with_data_encryption_key(): return EncryptionMaterials( algorithm=AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, data_encryption_key=RawDataKey( - key_provider=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=b"5430b043-5843-4629-869c-64794af77ada"), + key_provider=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=_EXISTING_KEY_ID), data_key=b'*!\xa1"^-(\xf3\x105\x05i@B\xc2\xa2\xb7\xdd\xd5\xd5\xa9\xddm\xfae\xa8\\$\xf9d\x1e(', ), encryption_context=_ENCRYPTION_CONTEXT, signing_key=_SIGNING_KEY, keyring_trace=[ KeyringTrace( - wrapping_key=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=b"5430b043-5843-4629-869c-64794af77ada"), + wrapping_key=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=_EXISTING_KEY_ID), flags={KeyringTraceFlag.GENERATED_DATA_KEY}, ) ], @@ -157,12 +158,12 @@ def get_encryption_materials_with_encrypted_data_key(): return EncryptionMaterials( algorithm=AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, data_encryption_key=RawDataKey( - key_provider=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=_KEY_ID), + key_provider=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=_EXISTING_KEY_ID), data_key=b'*!\xa1"^-(\xf3\x105\x05i@B\xc2\xa2\xb7\xdd\xd5\xd5\xa9\xddm\xfae\xa8\\$\xf9d\x1e(', ), encrypted_data_keys=[ EncryptedDataKey( - key_provider=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=_KEY_ID), + key_provider=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=_EXISTING_KEY_ID), encrypted_data_key=b"\xde^\x97\x7f\x84\xe9\x9e\x98\xd0\xe2\xf8\xd5\xcb\xe9\x7f.}\x87\x16,\x11n#\xc8p" b"\xdb\xbf\x94\x86*Q\x06\xd2\xf5\xdah\x08\xa4p\x81\xf7\xf4G\x07FzE\xde", ) @@ -171,7 +172,7 @@ def get_encryption_materials_with_encrypted_data_key(): signing_key=_SIGNING_KEY, keyring_trace=[ KeyringTrace( - wrapping_key=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=_KEY_ID), + wrapping_key=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=_EXISTING_KEY_ID), flags={KeyringTraceFlag.GENERATED_DATA_KEY, KeyringTraceFlag.ENCRYPTED_DATA_KEY}, ) ], @@ -182,7 +183,7 @@ def get_encryption_materials_with_encrypted_data_key_aes(): return EncryptionMaterials( algorithm=AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, data_encryption_key=RawDataKey( - key_provider=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=_KEY_ID), + key_provider=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=_EXISTING_KEY_ID), data_key=b'*!\xa1"^-(\xf3\x105\x05i@B\xc2\xa2\xb7\xdd\xd5\xd5\xa9\xddm\xfae\xa8\\$\xf9d\x1e(', ), encrypted_data_keys=[_ENCRYPTED_DATA_KEY_AES], @@ -190,7 +191,7 @@ def get_encryption_materials_with_encrypted_data_key_aes(): signing_key=_SIGNING_KEY, keyring_trace=[ KeyringTrace( - wrapping_key=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=_KEY_ID), + wrapping_key=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=_EXISTING_KEY_ID), flags={KeyringTraceFlag.GENERATED_DATA_KEY, KeyringTraceFlag.ENCRYPTED_DATA_KEY}, ) ], @@ -217,14 +218,14 @@ def get_decryption_materials_with_data_key(): return DecryptionMaterials( algorithm=AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, data_encryption_key=RawDataKey( - key_provider=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=_KEY_ID), + key_provider=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=_EXISTING_KEY_ID), data_key=b'*!\xa1"^-(\xf3\x105\x05i@B\xc2\xa2\xb7\xdd\xd5\xd5\xa9\xddm\xfae\xa8\\$\xf9d\x1e(', ), encryption_context=_ENCRYPTION_CONTEXT, verification_key=b"ex_verification_key", keyring_trace=[ KeyringTrace( - wrapping_key=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=_KEY_ID), + wrapping_key=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=_EXISTING_KEY_ID), flags={KeyringTraceFlag.DECRYPTED_DATA_KEY}, ) ], @@ -235,14 +236,14 @@ def get_decryption_materials_with_data_encryption_key(): return DecryptionMaterials( algorithm=AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, data_encryption_key=RawDataKey( - key_provider=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=b"5430b043-5843-4629-869c-64794af77ada"), + key_provider=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=_EXISTING_KEY_ID), data_key=b'*!\xa1"^-(\xf3\x105\x05i@B\xc2\xa2\xb7\xdd\xd5\xd5\xa9\xddm\xfae\xa8\\$\xf9d\x1e(', ), encryption_context=_ENCRYPTION_CONTEXT, verification_key=b"ex_verification_key", keyring_trace=[ KeyringTrace( - wrapping_key=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=b"5430b043-5843-4629-869c-64794af77ada"), + wrapping_key=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=_EXISTING_KEY_ID), flags={KeyringTraceFlag.DECRYPTED_DATA_KEY}, ) ],