Skip to content

Commit 88130e7

Browse files
committed
PR Feedback
1 parent 2280fac commit 88130e7

File tree

8 files changed

+110
-81
lines changed

8 files changed

+110
-81
lines changed

key-value/key-value-aio/src/key_value/aio/wrappers/encryption/base.py

Lines changed: 25 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,15 @@
44
from typing import Any, SupportsFloat
55

66
from key_value.shared.errors.key_value import SerializationError
7-
from key_value.shared.errors.wrappers.encryption import DecryptionError
7+
from key_value.shared.errors.wrappers.encryption import CorruptedEncryptionDataError, DecryptionError
88
from typing_extensions import override
99

1010
from key_value.aio.protocols.key_value import AsyncKeyValue
1111
from key_value.aio.wrappers.base import BaseWrapper
1212

13-
# Special keys used to store encrypted data
1413
_ENCRYPTED_DATA_KEY = "__encrypted_data__"
1514
_ENCRYPTION_VERSION_KEY = "__encryption_version__"
16-
_ENCRYPTION_VERSION = 1
15+
1716

1817
EncryptionFn = Callable[[bytes], bytes]
1918
DecryptionFn = Callable[[bytes, int], bytes]
@@ -26,18 +25,9 @@ class EncryptionError(Exception):
2625
class BaseEncryptionWrapper(BaseWrapper):
2726
"""Wrapper that encrypts values before storing and decrypts on retrieval.
2827
29-
This wrapper encrypts the JSON-serialized value using Fernet (symmetric encryption)
28+
This wrapper encrypts the JSON-serialized value using a custom encryption function
3029
and stores it as a base64-encoded string within a special key in the dictionary.
3130
This allows encryption while maintaining the dict[str, Any] interface.
32-
33-
The encrypted format looks like:
34-
{
35-
"__encrypted_data__": "base64-encoded-encrypted-data",
36-
"__encryption_version__": 1
37-
}
38-
39-
Note: The encryption key must be kept secret and secure. If the key is lost,
40-
encrypted data cannot be recovered.
4131
"""
4232

4333
def __init__(
@@ -69,22 +59,31 @@ def __init__(
6959
super().__init__()
7060

7161
def _encrypt_value(self, value: dict[str, Any]) -> dict[str, Any]:
72-
"""Encrypt a value into the encrypted format."""
62+
"""Encrypt a value into the encrypted format.
63+
64+
The encrypted format looks like:
65+
{
66+
"__encrypted_data__": "base64-encoded-encrypted-data",
67+
"__encryption_version__": 1
68+
}
69+
"""
7370

7471
# Serialize to JSON
7572
try:
7673
json_str: str = json.dumps(value, separators=(",", ":"))
74+
75+
json_bytes: bytes = json_str.encode(encoding="utf-8")
7776
except (json.JSONDecodeError, TypeError) as e:
7877
msg: str = f"Failed to serialize object to JSON: {e}"
7978
raise SerializationError(msg) from e
8079

81-
json_bytes: bytes = json_str.encode(encoding="utf-8")
82-
83-
# Encrypt with Fernet
84-
encrypted_bytes: bytes = self._encryption_fn(json_bytes)
80+
try:
81+
encrypted_bytes: bytes = self._encryption_fn(json_bytes)
8582

86-
# Encode to base64 for storage in dict (though Fernet output is already base64)
87-
base64_str: str = base64.b64encode(encrypted_bytes).decode(encoding="ascii")
83+
base64_str: str = base64.b64encode(encrypted_bytes).decode(encoding="ascii")
84+
except Exception as e:
85+
msg = "Failed to encrypt value"
86+
raise EncryptionError(msg) from e
8887

8988
return {
9089
_ENCRYPTED_DATA_KEY: base64_str,
@@ -96,34 +95,30 @@ def _decrypt_value(self, value: dict[str, Any] | None) -> dict[str, Any] | None:
9695
if value is None:
9796
return None
9897

99-
if _ENCRYPTED_DATA_KEY not in value:
98+
if _ENCRYPTED_DATA_KEY not in value and isinstance(value, dict): # pyright: ignore[reportUnnecessaryIsInstance]
10099
return value
101100

102101
base64_str = value[_ENCRYPTED_DATA_KEY]
103102
if not isinstance(base64_str, str):
104-
# Corrupted data, return as-is
105103
msg = f"Corrupted data: expected str, got {type(base64_str)}"
106-
raise TypeError(msg)
104+
raise CorruptedEncryptionDataError(msg)
107105

108106
if _ENCRYPTION_VERSION_KEY not in value:
109107
msg = "Corrupted data: missing encryption version"
110-
raise TypeError(msg)
108+
raise CorruptedEncryptionDataError(msg)
111109

112110
encryption_version = value[_ENCRYPTION_VERSION_KEY]
113111
if not isinstance(encryption_version, int):
114-
# Corrupted data, return as-is
115112
msg = f"Corrupted data: expected int, got {type(encryption_version)}"
116-
raise TypeError(msg)
113+
raise CorruptedEncryptionDataError(msg)
117114

118115
try:
119-
# Decode from base64
120-
encrypted_bytes: bytes = base64.b64decode(base64_str)
116+
encrypted_bytes: bytes = base64.b64decode(base64_str, validate=True)
121117

122-
# Decrypt with Fernet
123118
json_bytes: bytes = self._decryption_fn(encrypted_bytes, encryption_version)
124119

125-
# Parse JSON
126120
json_str: str = json_bytes.decode(encoding="utf-8")
121+
127122
return json.loads(json_str) # type: ignore[no-any-return]
128123
except Exception as e:
129124
msg = "Failed to decrypt value"

key-value/key-value-aio/src/key_value/aio/wrappers/encryption/fernet.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from cryptography.fernet import Fernet
1+
from cryptography.fernet import Fernet, MultiFernet
22
from key_value.shared.errors.wrappers.encryption import EncryptionVersionError
33
from typing_extensions import overload
44

@@ -7,21 +7,26 @@
77

88
ENCRYPTION_VERSION = 1
99

10+
KDF_ITERATIONS = 1_200_000
11+
1012

1113
class FernetEncryptionWrapper(BaseEncryptionWrapper):
14+
"""Wrapper that encrypts values before storing and decrypts on retrieval using Fernet (symmetric encryption)."""
15+
1216
@overload
1317
def __init__(
1418
self,
1519
key_value: AsyncKeyValue,
1620
*,
17-
fernet: Fernet,
21+
fernet: Fernet | MultiFernet,
1822
raise_on_decryption_error: bool = True,
1923
) -> None:
2024
"""Initialize the Fernet encryption wrapper.
2125
2226
Args:
2327
key_value: The key-value store to wrap.
24-
fernet: The Fernet instance to use for encryption and decryption.
28+
fernet: The Fernet or MultiFernet instance to use for encryption and decryption MultiFernet is used to support
29+
key rotation by allowing you to provide multiple Fernet instances that are attempted in order.
2530
raise_on_decryption_error: Whether to raise an exception if decryption fails. Defaults to True.
2631
"""
2732

@@ -47,21 +52,21 @@ def __init__(
4752
self,
4853
key_value: AsyncKeyValue,
4954
*,
50-
fernet: Fernet | None = None,
55+
fernet: Fernet | MultiFernet | None = None,
5156
source_material: str | None = None,
5257
salt: str | None = None,
5358
raise_on_decryption_error: bool = True,
5459
) -> None:
5560
if fernet is not None: # noqa: SIM102
5661
if source_material or salt:
57-
msg = "Cannot provide both fernet and source_material and salt"
62+
msg = "Cannot provide fernet together with source_material or salt"
5863
raise ValueError(msg)
5964

6065
if fernet is None:
61-
if not source_material:
66+
if not source_material or not source_material.strip():
6267
msg = "Must provide either fernet or source_material"
6368
raise ValueError(msg)
64-
if not salt:
69+
if not salt or not salt.strip():
6570
msg = "Must provide a salt"
6671
raise ValueError(msg)
6772
fernet = Fernet(key=_generate_encryption_key(source_material=source_material, salt=salt))
@@ -85,7 +90,7 @@ def decrypt_with_fernet(data: bytes, encryption_version: int) -> bytes:
8590

8691

8792
def _generate_encryption_key(source_material: str, salt: str) -> bytes:
88-
"""Generate a Fernet encryption key from a source material and salt using PBKDF2 with 1.2 million iterations."""
93+
"""Generate a Fernet encryption key from a source material and salt using PBKDF2."""
8994
import base64
9095

9196
from cryptography.hazmat.primitives import hashes
@@ -95,7 +100,7 @@ def _generate_encryption_key(source_material: str, salt: str) -> bytes:
95100
algorithm=hashes.SHA256(),
96101
length=32,
97102
salt=salt.encode(),
98-
iterations=1_200_000,
103+
iterations=KDF_ITERATIONS,
99104
).derive(key_material=source_material.encode())
100105

101106
return base64.urlsafe_b64encode(pbkdf2)

key-value/key-value-aio/tests/stores/wrappers/test_encryption.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pytest
2-
from cryptography.fernet import Fernet
2+
from cryptography.fernet import Fernet, MultiFernet
33
from dirty_equals import IsStr
44
from inline_snapshot import snapshot
55
from key_value.shared.errors.wrappers.encryption import DecryptionError
@@ -123,6 +123,19 @@ async def test_decryption_ignores_corrupted_data(self, memory_store: MemoryStore
123123

124124
assert await store.get(collection="test", key="test") is None
125125

126+
async def test_decryption_with_multi_fernet(self, memory_store: MemoryStore):
127+
"""Test that decryption works with a MultiFernet."""
128+
first_fernet = Fernet(key=Fernet.generate_key())
129+
first_fernet_store = FernetEncryptionWrapper(key_value=memory_store, fernet=first_fernet)
130+
original_value = {"test": "value"}
131+
await first_fernet_store.put(collection="test", key="test", value=original_value)
132+
assert await first_fernet_store.get(collection="test", key="test") == original_value
133+
134+
second_fernet = Fernet(key=Fernet.generate_key())
135+
multi_fernet = MultiFernet([second_fernet, first_fernet])
136+
multi_fernet_store = FernetEncryptionWrapper(key_value=memory_store, fernet=multi_fernet)
137+
assert await multi_fernet_store.get(collection="test", key="test") == original_value
138+
126139
async def test_decryption_with_wrong_key_raises_error(self, memory_store: MemoryStore):
127140
"""Test that decryption with the wrong key raises an error."""
128141
fernet1 = Fernet(key=Fernet.generate_key())

key-value/key-value-shared/src/key_value/shared/errors/wrappers/encryption.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,7 @@ class DecryptionError(EncryptionError):
1111

1212
class EncryptionVersionError(EncryptionError):
1313
"""Exception raised when the encryption version is not supported."""
14+
15+
16+
class CorruptedEncryptionDataError(EncryptionError):
17+
"""Exception raised when the encrypted data is corrupted."""

key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def from_json(cls, json_str: str, includes_metadata: bool = True, ttl: SupportsF
8585

8686
def dump_to_json(obj: dict[str, Any]) -> str:
8787
try:
88-
return json.dumps(obj)
88+
return json.dumps(obj, separators=(",", ":"))
8989
except (json.JSONDecodeError, TypeError) as e:
9090
msg: str = f"Failed to serialize object to JSON: {e}"
9191
raise SerializationError(msg) from e

key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/encryption/base.py

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,14 @@
77
from typing import Any, SupportsFloat
88

99
from key_value.shared.errors.key_value import SerializationError
10-
from key_value.shared.errors.wrappers.encryption import DecryptionError
10+
from key_value.shared.errors.wrappers.encryption import CorruptedEncryptionDataError, DecryptionError
1111
from typing_extensions import override
1212

1313
from key_value.sync.code_gen.protocols.key_value import KeyValue
1414
from key_value.sync.code_gen.wrappers.base import BaseWrapper
1515

16-
# Special keys used to store encrypted data
1716
_ENCRYPTED_DATA_KEY = "__encrypted_data__"
1817
_ENCRYPTION_VERSION_KEY = "__encryption_version__"
19-
_ENCRYPTION_VERSION = 1
2018

2119
EncryptionFn = Callable[[bytes], bytes]
2220
DecryptionFn = Callable[[bytes, int], bytes]
@@ -29,18 +27,9 @@ class EncryptionError(Exception):
2927
class BaseEncryptionWrapper(BaseWrapper):
3028
"""Wrapper that encrypts values before storing and decrypts on retrieval.
3129
32-
This wrapper encrypts the JSON-serialized value using Fernet (symmetric encryption)
30+
This wrapper encrypts the JSON-serialized value using a custom encryption function
3331
and stores it as a base64-encoded string within a special key in the dictionary.
3432
This allows encryption while maintaining the dict[str, Any] interface.
35-
36-
The encrypted format looks like:
37-
{
38-
"__encrypted_data__": "base64-encoded-encrypted-data",
39-
"__encryption_version__": 1
40-
}
41-
42-
Note: The encryption key must be kept secret and secure. If the key is lost,
43-
encrypted data cannot be recovered.
4433
"""
4534

4635
def __init__(
@@ -72,22 +61,31 @@ def __init__(
7261
super().__init__()
7362

7463
def _encrypt_value(self, value: dict[str, Any]) -> dict[str, Any]:
75-
"""Encrypt a value into the encrypted format."""
64+
"""Encrypt a value into the encrypted format.
65+
66+
The encrypted format looks like:
67+
{
68+
"__encrypted_data__": "base64-encoded-encrypted-data",
69+
"__encryption_version__": 1
70+
}
71+
"""
7672

7773
# Serialize to JSON
7874
try:
7975
json_str: str = json.dumps(value, separators=(",", ":"))
76+
77+
json_bytes: bytes = json_str.encode(encoding="utf-8")
8078
except (json.JSONDecodeError, TypeError) as e:
8179
msg: str = f"Failed to serialize object to JSON: {e}"
8280
raise SerializationError(msg) from e
8381

84-
json_bytes: bytes = json_str.encode(encoding="utf-8")
85-
86-
# Encrypt with Fernet
87-
encrypted_bytes: bytes = self._encryption_fn(json_bytes)
82+
try:
83+
encrypted_bytes: bytes = self._encryption_fn(json_bytes)
8884

89-
# Encode to base64 for storage in dict (though Fernet output is already base64)
90-
base64_str: str = base64.b64encode(encrypted_bytes).decode(encoding="ascii")
85+
base64_str: str = base64.b64encode(encrypted_bytes).decode(encoding="ascii")
86+
except Exception as e:
87+
msg = "Failed to encrypt value"
88+
raise EncryptionError(msg) from e
9189

9290
return {_ENCRYPTED_DATA_KEY: base64_str, _ENCRYPTION_VERSION_KEY: self.encryption_version}
9391

@@ -96,34 +94,30 @@ def _decrypt_value(self, value: dict[str, Any] | None) -> dict[str, Any] | None:
9694
if value is None:
9795
return None
9896

99-
if _ENCRYPTED_DATA_KEY not in value:
97+
if _ENCRYPTED_DATA_KEY not in value and isinstance(value, dict): # pyright: ignore[reportUnnecessaryIsInstance]
10098
return value
10199

102100
base64_str = value[_ENCRYPTED_DATA_KEY]
103101
if not isinstance(base64_str, str):
104-
# Corrupted data, return as-is
105102
msg = f"Corrupted data: expected str, got {type(base64_str)}"
106-
raise TypeError(msg)
103+
raise CorruptedEncryptionDataError(msg)
107104

108105
if _ENCRYPTION_VERSION_KEY not in value:
109106
msg = "Corrupted data: missing encryption version"
110-
raise TypeError(msg)
107+
raise CorruptedEncryptionDataError(msg)
111108

112109
encryption_version = value[_ENCRYPTION_VERSION_KEY]
113110
if not isinstance(encryption_version, int):
114-
# Corrupted data, return as-is
115111
msg = f"Corrupted data: expected int, got {type(encryption_version)}"
116-
raise TypeError(msg)
112+
raise CorruptedEncryptionDataError(msg)
117113

118114
try:
119-
# Decode from base64
120-
encrypted_bytes: bytes = base64.b64decode(base64_str)
115+
encrypted_bytes: bytes = base64.b64decode(base64_str, validate=True)
121116

122-
# Decrypt with Fernet
123117
json_bytes: bytes = self._decryption_fn(encrypted_bytes, encryption_version)
124118

125-
# Parse JSON
126119
json_str: str = json_bytes.decode(encoding="utf-8")
120+
127121
return json.loads(json_str) # type: ignore[no-any-return]
128122
except Exception as e:
129123
msg = "Failed to decrypt value"

0 commit comments

Comments
 (0)