Skip to content

Commit 4af690a

Browse files
feat: Enforce BearType on Store/Wrapper methods (#126)
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com> Co-authored-by: William Easton <[email protected]>
1 parent 6505d3d commit 4af690a

File tree

8 files changed

+57
-28
lines changed

8 files changed

+57
-28
lines changed

README.md

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,12 @@ needs while keeping your framework code clean and backend-agnostic.
101101
- **No Live Objects**: Even when using the in-memory store, "live" objects are
102102
never returned from the store. You get a dictionary or a Pydantic model,
103103
hopefully a copy of what you stored, but never the same instance in memory.
104-
- **Dislike of Bear Bros**: Beartype is used for runtime type checking, it will
105-
report warnings if you get too cheeky with what you're passing around. If you
106-
are not a fan of beartype, you can disable it by setting the
107-
`PY_KEY_VALUE_DISABLE_BEARTYPE` environment variable to `true` or you can
108-
disable the warnings via the warn module.
104+
- **Dislike of Bear Bros**: Beartype is used for runtime type checking. Core
105+
protocol methods in store and wrapper implementations (put/get/delete/ttl
106+
and their batch variants) enforce types and will raise TypeError for
107+
violations. Other code produces warnings. You can disable all beartype
108+
checks by setting `PY_KEY_VALUE_DISABLE_BEARTYPE=true` or suppress warnings
109+
via the warnings module.
109110

110111
## Installation
111112

key-value/key-value-aio/src/key_value/aio/stores/base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from key_value.shared.constants import DEFAULT_COLLECTION_NAME
1313
from key_value.shared.errors import StoreSetupError
14+
from key_value.shared.type_checking.bear_spray import bear_enforce
1415
from key_value.shared.utils.managed_entry import ManagedEntry
1516
from key_value.shared.utils.time_to_live import now, prepare_ttl
1617
from typing_extensions import Self, override
@@ -134,6 +135,7 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) ->
134135

135136
return [await self._get_managed_entry(collection=collection, key=key) for key in keys]
136137

138+
@bear_enforce
137139
@override
138140
async def get(
139141
self,
@@ -163,6 +165,7 @@ async def get(
163165

164166
return dict(managed_entry.value)
165167

168+
@bear_enforce
166169
@override
167170
async def get_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[dict[str, Any] | None]:
168171
collection = collection or self.default_collection
@@ -171,6 +174,7 @@ async def get_many(self, keys: Sequence[str], *, collection: str | None = None)
171174
entries = await self._get_managed_entries(keys=keys, collection=collection)
172175
return [dict(entry.value) if entry and not entry.is_expired else None for entry in entries]
173176

177+
@bear_enforce
174178
@override
175179
async def ttl(self, key: str, *, collection: str | None = None) -> tuple[dict[str, Any] | None, float | None]:
176180
collection = collection or self.default_collection
@@ -183,6 +187,7 @@ async def ttl(self, key: str, *, collection: str | None = None) -> tuple[dict[st
183187

184188
return (dict(managed_entry.value), managed_entry.ttl)
185189

190+
@bear_enforce
186191
@override
187192
async def ttl_many(
188193
self,
@@ -216,6 +221,7 @@ async def _put_managed_entries(self, *, collection: str, keys: Sequence[str], ma
216221
managed_entry=managed_entry,
217222
)
218223

224+
@bear_enforce
219225
@override
220226
async def put(self, key: str, value: Mapping[str, Any], *, collection: str | None = None, ttl: SupportsFloat | None = None) -> None:
221227
"""Store a key-value pair in the specified collection with optional TTL."""
@@ -245,6 +251,7 @@ def _prepare_put_many(
245251

246252
return (keys, values, ttl_for_entries)
247253

254+
@bear_enforce
248255
@override
249256
async def put_many(
250257
self,
@@ -281,13 +288,15 @@ async def _delete_managed_entries(self, *, keys: Sequence[str], collection: str)
281288

282289
return deleted_count
283290

291+
@bear_enforce
284292
@override
285293
async def delete(self, key: str, *, collection: str | None = None) -> bool:
286294
collection = collection or self.default_collection
287295
await self.setup_collection(collection=collection)
288296

289297
return await self._delete_managed_entry(key=key, collection=collection)
290298

299+
@bear_enforce
291300
@override
292301
async def delete_many(self, keys: Sequence[str], *, collection: str | None = None) -> int:
293302
"""Delete multiple managed entries by key from the specified collection."""

key-value/key-value-aio/src/key_value/aio/wrappers/base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from collections.abc import Mapping, Sequence
22
from typing import Any, SupportsFloat
33

4+
from key_value.shared.type_checking.bear_spray import bear_enforce
45
from typing_extensions import override
56

67
from key_value.aio.protocols.key_value import AsyncKeyValue
@@ -11,26 +12,32 @@ class BaseWrapper(AsyncKeyValue):
1112

1213
key_value: AsyncKeyValue
1314

15+
@bear_enforce
1416
@override
1517
async def get(self, key: str, *, collection: str | None = None) -> dict[str, Any] | None:
1618
return await self.key_value.get(collection=collection, key=key)
1719

20+
@bear_enforce
1821
@override
1922
async def get_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[dict[str, Any] | None]:
2023
return await self.key_value.get_many(collection=collection, keys=keys)
2124

25+
@bear_enforce
2226
@override
2327
async def ttl(self, key: str, *, collection: str | None = None) -> tuple[dict[str, Any] | None, float | None]:
2428
return await self.key_value.ttl(collection=collection, key=key)
2529

30+
@bear_enforce
2631
@override
2732
async def ttl_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[tuple[dict[str, Any] | None, float | None]]:
2833
return await self.key_value.ttl_many(collection=collection, keys=keys)
2934

35+
@bear_enforce
3036
@override
3137
async def put(self, key: str, value: Mapping[str, Any], *, collection: str | None = None, ttl: SupportsFloat | None = None) -> None:
3238
return await self.key_value.put(collection=collection, key=key, value=value, ttl=ttl)
3339

40+
@bear_enforce
3441
@override
3542
async def put_many(
3643
self,
@@ -42,10 +49,12 @@ async def put_many(
4249
) -> None:
4350
return await self.key_value.put_many(keys=keys, values=values, collection=collection, ttl=ttl)
4451

52+
@bear_enforce
4553
@override
4654
async def delete(self, key: str, *, collection: str | None = None) -> bool:
4755
return await self.key_value.delete(collection=collection, key=key)
4856

57+
@bear_enforce
4958
@override
5059
async def delete_many(self, keys: Sequence[str], *, collection: str | None = None) -> int:
5160
return await self.key_value.delete_many(keys=keys, collection=collection)

key-value/key-value-shared-test/src/key_value/shared_test/cases.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -106,15 +106,8 @@ def parametrize(cls, cases: list[Self]) -> MarkDecorator:
106106
data={"str_key": {"str_key_2": None, "str_key_3": None}},
107107
json='{"str_key": {"str_key_2": null, "str_key_3": null}}',
108108
),
109-
Case(
110-
name="nested-dict-keys",
111-
data={"str_key": {"str_key_2": {None: "str_value"}}},
112-
json='{"str_key": {"str_key_2": {"null": "str_value"}}}',
113-
round_trip={"str_key": {"str_key_2": {"null": "str_value"}}},
114-
),
115109
Case(name="no-implicit-serialization-in-keys", data={"null": "str_value"}, json='{"null": "str_value"}'),
116110
Case(name="no-implicit-serialization-in-values", data={"str_key": "null"}, json='{"str_key": "null"}'),
117-
Case(name="implicit-serialization-of-null-key", data={None: True}, json='{"null": true}', round_trip={"null": True}), # type: ignore
118111
case_type="null",
119112
)
120113

@@ -129,16 +122,8 @@ def parametrize(cls, cases: list[Self]) -> MarkDecorator:
129122
data={"str_key": {"str_key_2": True, "str_key_3": False}},
130123
json='{"str_key": {"str_key_2": true, "str_key_3": false}}',
131124
),
132-
Case(
133-
name="nested-dict-keys",
134-
data={"str_key": {"str_key_2": {True: "str_value"}}},
135-
json='{"str_key": {"str_key_2": {"true": "str_value"}}}',
136-
round_trip={"str_key": {"str_key_2": {"true": "str_value"}}},
137-
),
138125
Case(name="no-implicit-serialization-in-keys", data={"true": "str_value"}, json='{"true": "str_value"}'),
139126
Case(name="no-implicit-serialization-in-values", data={"str_key": "true"}, json='{"str_key": "true"}'),
140-
Case(name="implicit-serialization-of-true-key", data={True: True}, json='{"true": true}', round_trip={"true": True}), # type: ignore
141-
Case(name="implicit-serialization-of-false-key", data={False: True}, json='{"false": true}', round_trip={"false": True}), # type: ignore
142127
case_type="boolean",
143128
)
144129

@@ -150,7 +135,6 @@ def parametrize(cls, cases: list[Self]) -> MarkDecorator:
150135
Case(name="int", data={"int_key": 1}, json='{"int_key": 1}'),
151136
Case(name="value-negative", data={"negative_int_key": -42}, json='{"negative_int_key": -42}'),
152137
Case(name="large-value", data={"large_int_key": 1 * 10**18}, json=f'{{"large_int_key": {1 * 10**18}}}'),
153-
Case(name="key", data={1: True}, json='{"1": true}', round_trip={"1": True}), # type: ignore
154138
Case(name="no-implicit-serialization-in-keys", data={"1": "str_value"}, json='{"1": "str_value"}'),
155139
Case(name="no-implicit-serialization-in-values", data={"str_key": "1"}, json='{"str_key": "1"}'),
156140
case_type="integer",
@@ -165,7 +149,6 @@ def parametrize(cls, cases: list[Self]) -> MarkDecorator:
165149
Case(name="large-value", data={"large_float_key": 1.0 * 10**63}, json=f'{{"large_float_key": {1.0 * 10**63}}}'),
166150
Case(name="no-implicit-serialization-in-keys", data={"1.0": "str_value"}, json='{"1.0": "str_value"}'),
167151
Case(name="no-implicit-serialization-in-values", data={"str_key": "1.0"}, json='{"str_key": "1.0"}'),
168-
Case(name="implicit-serialization-of-float-key", data={1.0: True}, json='{"1.0": true}', round_trip={"1.0": True}), # type: ignore
169152
case_type="float",
170153
)
171154

@@ -189,11 +172,8 @@ def parametrize(cls, cases: list[Self]) -> MarkDecorator:
189172
DATETIME_CASES: PositiveCases = PositiveCases(case_type="datetime")
190173

191174
NEGATIVE_DATETIME_CASES: NegativeCases = NegativeCases(
192-
NegativeCase(name="datetime-key", data={FIXED_DATETIME: True}), # type: ignore
193175
NegativeCase(name="datetime-value", data={"str_key": FIXED_DATETIME}),
194-
NegativeCase(name="date-key", data={FIXED_DATETIME.date(): True}), # type: ignore
195176
NegativeCase(name="date-value", data={"str_key": FIXED_DATETIME.date()}),
196-
NegativeCase(name="time-key", data={FIXED_TIME: True}), # type: ignore
197177
NegativeCase(name="time-value", data={"str_key": FIXED_TIME}),
198178
case_type="datetime",
199179
)
@@ -213,7 +193,6 @@ def parametrize(cls, cases: list[Self]) -> MarkDecorator:
213193
)
214194

215195
NEGATIVE_UUID_CASES: NegativeCases = NegativeCases(
216-
NegativeCase(name="key", data={FIXED_UUID: True}), # type: ignore
217196
NegativeCase(name="value", data={"str_key": FIXED_UUID}),
218197
case_type="uuid",
219198
)
@@ -228,7 +207,6 @@ def parametrize(cls, cases: list[Self]) -> MarkDecorator:
228207
)
229208

230209
NEGATIVE_BYTES_CASES: NegativeCases = NegativeCases(
231-
NegativeCase(name="bytes-key", data={B_HELLO_WORLD: True}), # type: ignore
232210
NegativeCase(name="bytes-value", data={"str_key": B_HELLO_WORLD}),
233211
case_type="bytes",
234212
)
@@ -246,7 +224,6 @@ def parametrize(cls, cases: list[Self]) -> MarkDecorator:
246224
)
247225

248226
NEGATIVE_TUPLE_CASES: NegativeCases = NegativeCases(
249-
NegativeCase(name="key", data={SAMPLE_TUPLE: True}), # type: ignore
250227
NegativeCase(name="value", data={"str_key": SAMPLE_TUPLE}),
251228
case_type="tuple",
252229
)

key-value/key-value-shared/src/key_value/shared/type_checking/bear_spray.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
no_bear_type = beartype(conf=no_bear_type_check_conf)
99

10+
enforce_bear_type_conf = BeartypeConf(strategy=BeartypeStrategy.O1, violation_type=TypeError)
11+
enforce_bear_type = beartype(conf=enforce_bear_type_conf)
12+
1013
P = ParamSpec("P")
1114
R = TypeVar("R")
1215

@@ -15,4 +18,9 @@ def no_bear_type_check(func: Callable[P, R]) -> Callable[P, R]:
1518
return no_bear_type(func)
1619

1720

21+
def bear_enforce(func: Callable[P, R]) -> Callable[P, R]:
22+
"""Enforce beartype with exceptions instead of warnings."""
23+
return enforce_bear_type(func)
24+
25+
1826
bear_spray = no_bear_type_check

key-value/key-value-sync/src/key_value/sync/code_gen/stores/base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from key_value.shared.constants import DEFAULT_COLLECTION_NAME
1616
from key_value.shared.errors import StoreSetupError
17+
from key_value.shared.type_checking.bear_spray import bear_enforce
1718
from key_value.shared.utils.managed_entry import ManagedEntry
1819
from key_value.shared.utils.time_to_live import now, prepare_ttl
1920
from typing_extensions import Self, override
@@ -140,6 +141,7 @@ def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[
140141

141142
return [self._get_managed_entry(collection=collection, key=key) for key in keys]
142143

144+
@bear_enforce
143145
@override
144146
def get(self, key: str, *, collection: str | None = None) -> dict[str, Any] | None:
145147
"""Retrieve a value by key from the specified collection.
@@ -164,6 +166,7 @@ def get(self, key: str, *, collection: str | None = None) -> dict[str, Any] | No
164166

165167
return dict(managed_entry.value)
166168

169+
@bear_enforce
167170
@override
168171
def get_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[dict[str, Any] | None]:
169172
collection = collection or self.default_collection
@@ -172,6 +175,7 @@ def get_many(self, keys: Sequence[str], *, collection: str | None = None) -> lis
172175
entries = self._get_managed_entries(keys=keys, collection=collection)
173176
return [dict(entry.value) if entry and (not entry.is_expired) else None for entry in entries]
174177

178+
@bear_enforce
175179
@override
176180
def ttl(self, key: str, *, collection: str | None = None) -> tuple[dict[str, Any] | None, float | None]:
177181
collection = collection or self.default_collection
@@ -184,6 +188,7 @@ def ttl(self, key: str, *, collection: str | None = None) -> tuple[dict[str, Any
184188

185189
return (dict(managed_entry.value), managed_entry.ttl)
186190

191+
@bear_enforce
187192
@override
188193
def ttl_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[tuple[dict[str, Any] | None, float | None]]:
189194
"""Retrieve multiple values and TTLs by key from the specified collection.
@@ -208,6 +213,7 @@ def _put_managed_entries(self, *, collection: str, keys: Sequence[str], managed_
208213
for key, managed_entry in zip(keys, managed_entries, strict=True):
209214
self._put_managed_entry(collection=collection, key=key, managed_entry=managed_entry)
210215

216+
@bear_enforce
211217
@override
212218
def put(self, key: str, value: Mapping[str, Any], *, collection: str | None = None, ttl: SupportsFloat | None = None) -> None:
213219
"""Store a key-value pair in the specified collection with optional TTL."""
@@ -233,6 +239,7 @@ def _prepare_put_many(
233239

234240
return (keys, values, ttl_for_entries)
235241

242+
@bear_enforce
236243
@override
237244
def put_many(
238245
self, keys: Sequence[str], values: Sequence[Mapping[str, Any]], *, collection: str | None = None, ttl: SupportsFloat | None = None
@@ -264,13 +271,15 @@ def _delete_managed_entries(self, *, keys: Sequence[str], collection: str) -> in
264271

265272
return deleted_count
266273

274+
@bear_enforce
267275
@override
268276
def delete(self, key: str, *, collection: str | None = None) -> bool:
269277
collection = collection or self.default_collection
270278
self.setup_collection(collection=collection)
271279

272280
return self._delete_managed_entry(key=key, collection=collection)
273281

282+
@bear_enforce
274283
@override
275284
def delete_many(self, keys: Sequence[str], *, collection: str | None = None) -> int:
276285
"""Delete multiple managed entries by key from the specified collection."""

key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from collections.abc import Mapping, Sequence
55
from typing import Any, SupportsFloat
66

7+
from key_value.shared.type_checking.bear_spray import bear_enforce
78
from typing_extensions import override
89

910
from key_value.sync.code_gen.protocols.key_value import KeyValue
@@ -14,36 +15,44 @@ class BaseWrapper(KeyValue):
1415

1516
key_value: KeyValue
1617

18+
@bear_enforce
1719
@override
1820
def get(self, key: str, *, collection: str | None = None) -> dict[str, Any] | None:
1921
return self.key_value.get(collection=collection, key=key)
2022

23+
@bear_enforce
2124
@override
2225
def get_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[dict[str, Any] | None]:
2326
return self.key_value.get_many(collection=collection, keys=keys)
2427

28+
@bear_enforce
2529
@override
2630
def ttl(self, key: str, *, collection: str | None = None) -> tuple[dict[str, Any] | None, float | None]:
2731
return self.key_value.ttl(collection=collection, key=key)
2832

33+
@bear_enforce
2934
@override
3035
def ttl_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[tuple[dict[str, Any] | None, float | None]]:
3136
return self.key_value.ttl_many(collection=collection, keys=keys)
3237

38+
@bear_enforce
3339
@override
3440
def put(self, key: str, value: Mapping[str, Any], *, collection: str | None = None, ttl: SupportsFloat | None = None) -> None:
3541
return self.key_value.put(collection=collection, key=key, value=value, ttl=ttl)
3642

43+
@bear_enforce
3744
@override
3845
def put_many(
3946
self, keys: Sequence[str], values: Sequence[Mapping[str, Any]], *, collection: str | None = None, ttl: SupportsFloat | None = None
4047
) -> None:
4148
return self.key_value.put_many(keys=keys, values=values, collection=collection, ttl=ttl)
4249

50+
@bear_enforce
4351
@override
4452
def delete(self, key: str, *, collection: str | None = None) -> bool:
4553
return self.key_value.delete(collection=collection, key=key)
4654

55+
@bear_enforce
4756
@override
4857
def delete_many(self, keys: Sequence[str], *, collection: str | None = None) -> int:
4958
return self.key_value.delete_many(keys=keys, collection=collection)

sonar-project.properties

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# SonarQube project configuration
2+
sonar.projectKey=strawgate_py-key-value
3+
sonar.organization=strawgate
4+
5+
# Exclude code-generated sync library from duplication detection
6+
# The sync library is automatically generated from the async library
7+
sonar.cpd.exclusions=**/key-value-sync/**/*.py

0 commit comments

Comments
 (0)