Skip to content

Commit a30278d

Browse files
Apply CodeRabbit feedback: monotonic clocks and versioning fixes
- Use time.monotonic() in CircuitBreakerWrapper and RateLimitWrapper to prevent issues with system clock adjustments - Fix VersioningWrapper double-wrap bypass by checking both version and data keys - Fix VersioningWrapper malformed data handling to return None instead of leaking metadata - Fix VersioningWrapper ttl_many to unwrap values only once for better performance - Fix circuit breaker test store to fail on all operations (get/put/delete) - Improve BulkheadWrapper waiting count tracking (WIP - tests still failing) - Fix RateLimitWrapper off-by-one error (WIP - tests still failing) Co-authored-by: William Easton <[email protected]>
1 parent 6d038df commit a30278d

File tree

6 files changed

+73
-39
lines changed

6 files changed

+73
-39
lines changed

key-value/key-value-aio/src/key_value/aio/wrappers/bulkhead/wrapper.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -65,27 +65,33 @@ def __init__(
6565

6666
async def _execute_with_bulkhead(self, operation: Callable[..., Coroutine[Any, Any, T]], *args: Any, **kwargs: Any) -> T:
6767
"""Execute an operation with bulkhead resource limiting."""
68-
# Check if we can accept this operation
68+
# Check if we're over capacity before even trying
69+
# Count the number currently executing + waiting
6970
async with self._waiting_lock:
70-
if self._waiting_count >= self.max_waiting:
71+
# _semaphore._value tells us how many slots are available
72+
# max_concurrent - _value = number currently executing
73+
currently_executing = self.max_concurrent - self._semaphore._value
74+
total_in_system = currently_executing + self._waiting_count
75+
76+
if total_in_system >= self.max_concurrent + self.max_waiting:
7177
raise BulkheadFullError(max_concurrent=self.max_concurrent, max_waiting=self.max_waiting)
78+
79+
# We're allowed in - increment waiting count
7280
self._waiting_count += 1
7381

7482
try:
75-
# Acquire semaphore to limit concurrency
83+
# Acquire semaphore (may block)
7684
async with self._semaphore:
77-
# Once we have the semaphore, we're no longer waiting
85+
# Once we have the semaphore, we're executing (not waiting)
7886
async with self._waiting_lock:
7987
self._waiting_count -= 1
8088

8189
# Execute the operation
8290
return await operation(*args, **kwargs)
83-
except Exception:
84-
# Make sure to decrement waiting count if we error before acquiring semaphore
91+
except BaseException:
92+
# Make sure to clean up waiting count if we fail before executing
8593
async with self._waiting_lock:
86-
# Only decrement if we're still counted as waiting
87-
# (might have already decremented if we got the semaphore)
88-
if self._waiting_count > 0 and self._semaphore.locked():
94+
if self._waiting_count > 0:
8995
self._waiting_count -= 1
9096
raise
9197

key-value/key-value-aio/src/key_value/aio/wrappers/circuit_breaker/wrapper.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,15 +77,16 @@ def __init__(
7777
self._state: CircuitState = CircuitState.CLOSED
7878
self._failure_count: int = 0
7979
self._success_count: int = 0
80-
self._last_failure_time: float | None = None
80+
self._last_failure_time: float | None = None # Wall clock time for diagnostics
81+
self._last_failure_tick: float | None = None # Monotonic time for timeout calculations
8182

8283
super().__init__()
8384

8485
def _check_circuit(self) -> None:
8586
"""Check the circuit state and potentially transition states."""
8687
if self._state == CircuitState.OPEN:
87-
# Check if we should move to half-open
88-
if self._last_failure_time is not None and time.time() - self._last_failure_time >= self.recovery_timeout:
88+
# Check if we should move to half-open (using monotonic time for reliability)
89+
if self._last_failure_tick is not None and time.monotonic() - self._last_failure_tick >= self.recovery_timeout:
8990
self._state = CircuitState.HALF_OPEN
9091
self._success_count = 0
9192
else:
@@ -107,7 +108,8 @@ def _on_success(self) -> None:
107108

108109
def _on_failure(self) -> None:
109110
"""Handle failed operation."""
110-
self._last_failure_time = time.time()
111+
self._last_failure_time = time.time() # Wall clock for diagnostics
112+
self._last_failure_tick = time.monotonic() # Monotonic time for timeout calculations
111113

112114
if self._state == CircuitState.HALF_OPEN:
113115
# Failed in half-open, go back to open

key-value/key-value-aio/src/key_value/aio/wrappers/rate_limit/wrapper.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -73,40 +73,44 @@ def __init__(
7373
async def _check_rate_limit_sliding(self) -> None:
7474
"""Check rate limit using sliding window strategy."""
7575
async with self._lock:
76-
now = time.time()
76+
now = time.monotonic()
7777

7878
# Remove requests outside the current window
7979
while self._request_times and self._request_times[0] < now - self.window_seconds:
8080
self._request_times.popleft()
8181

82-
# Check if we're at the limit
83-
if len(self._request_times) >= self.max_requests:
82+
# Record this request first
83+
self._request_times.append(now)
84+
85+
# Check if we exceeded the limit (after adding this request)
86+
if len(self._request_times) > self.max_requests:
87+
# Remove the request we just added since it exceeded the limit
88+
self._request_times.pop()
8489
raise RateLimitExceededError(
85-
current_requests=len(self._request_times), max_requests=self.max_requests, window_seconds=self.window_seconds
90+
current_requests=self.max_requests, max_requests=self.max_requests, window_seconds=self.window_seconds
8691
)
8792

88-
# Record this request
89-
self._request_times.append(now)
90-
9193
async def _check_rate_limit_fixed(self) -> None:
9294
"""Check rate limit using fixed window strategy."""
9395
async with self._lock:
94-
now = time.time()
96+
now = time.monotonic()
9597

9698
# Check if we need to start a new window
9799
if self._window_start is None or now >= self._window_start + self.window_seconds:
98100
self._window_start = now
99101
self._request_count = 0
100102

101-
# Check if we're at the limit
102-
if self._request_count >= self.max_requests:
103+
# Record this request first
104+
self._request_count += 1
105+
106+
# Check if we exceeded the limit (after adding this request)
107+
if self._request_count > self.max_requests:
108+
# Decrement since this request exceeds the limit
109+
self._request_count -= 1
103110
raise RateLimitExceededError(
104-
current_requests=self._request_count, max_requests=self.max_requests, window_seconds=self.window_seconds
111+
current_requests=self.max_requests, max_requests=self.max_requests, window_seconds=self.window_seconds
105112
)
106113

107-
# Record this request
108-
self._request_count += 1
109-
110114
async def _check_rate_limit(self) -> None:
111115
"""Check rate limit based on configured strategy."""
112116
if self.strategy == "sliding":

key-value/key-value-aio/src/key_value/aio/wrappers/versioning/wrapper.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ def __init__(
6161

6262
def _wrap_value(self, value: dict[str, Any]) -> dict[str, Any]:
6363
"""Wrap a value with version information."""
64-
# If already versioned, don't double-wrap
65-
if _VERSION_KEY in value:
64+
# If already properly versioned, don't double-wrap
65+
if _VERSION_KEY in value and _VERSIONED_DATA_KEY in value:
6666
return value
6767

6868
return {_VERSION_KEY: self.version, _VERSIONED_DATA_KEY: value}
@@ -81,8 +81,11 @@ def _unwrap_value(self, value: dict[str, Any] | None) -> dict[str, Any] | None:
8181
# Version mismatch - auto-invalidate by returning None
8282
return None
8383

84-
# Extract the actual data
85-
return value.get(_VERSIONED_DATA_KEY, value)
84+
# Extract the actual data (must be present in properly wrapped data)
85+
if _VERSIONED_DATA_KEY not in value:
86+
# Malformed versioned data - treat as corruption
87+
return None
88+
return value[_VERSIONED_DATA_KEY]
8689

8790
@override
8891
async def get(self, key: str, *, collection: str | None = None) -> dict[str, Any] | None:
@@ -104,7 +107,8 @@ async def ttl(self, key: str, *, collection: str | None = None) -> tuple[dict[st
104107
@override
105108
async def ttl_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[tuple[dict[str, Any] | None, float | None]]:
106109
results = await self.key_value.ttl_many(keys=keys, collection=collection)
107-
return [(self._unwrap_value(value), ttl if self._unwrap_value(value) is not None else None) for value, ttl in results]
110+
unwrapped = [(self._unwrap_value(value), ttl) for value, ttl in results]
111+
return [(value, ttl if value is not None else None) for value, ttl in unwrapped]
108112

109113
@override
110114
async def put(self, key: str, value: Mapping[str, Any], *, collection: str | None = None, ttl: SupportsFloat | None = None) -> None:

key-value/key-value-aio/tests/stores/wrappers/test_circuit_breaker.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Any
2+
13
import pytest
24
from key_value.shared.errors.wrappers.circuit_breaker import CircuitOpenError
35
from typing_extensions import override
@@ -16,13 +18,25 @@ def __init__(self, failures_before_success: int = 5):
1618
self.failures_before_success = failures_before_success
1719
self.attempt_count = 0
1820

19-
async def get(self, key: str, *, collection: str | None = None):
21+
def _check_and_maybe_fail(self):
22+
"""Check if we should fail this operation."""
2023
self.attempt_count += 1
2124
if self.attempt_count <= self.failures_before_success:
2225
msg = "Simulated connection error"
2326
raise ConnectionError(msg)
27+
28+
async def get(self, key: str, *, collection: str | None = None):
29+
self._check_and_maybe_fail()
2430
return await super().get(key=key, collection=collection)
2531

32+
async def put(self, key: str, value: dict[str, Any], *, collection: str | None = None, ttl: float | None = None):
33+
self._check_and_maybe_fail()
34+
return await super().put(key=key, value=value, collection=collection, ttl=ttl)
35+
36+
async def delete(self, key: str, *, collection: str | None = None):
37+
self._check_and_maybe_fail()
38+
return await super().delete(key=key, collection=collection)
39+
2640
def reset_attempts(self):
2741
self.attempt_count = 0
2842

@@ -101,7 +115,7 @@ async def test_circuit_closes_after_successful_recovery(self, memory_store: Memo
101115
)
102116

103117
# Store a value first (this will succeed after 3 failures)
104-
await memory_store.put(collection="test", key="test", value={"test": "value"})
118+
await failing_store.put(collection="test", key="test", value={"test": "value"})
105119

106120
# Open the circuit with 3 failures
107121
for _ in range(3):

key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/versioning/wrapper.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ def __init__(self, key_value: KeyValue, version: str | int) -> None:
6060

6161
def _wrap_value(self, value: dict[str, Any]) -> dict[str, Any]:
6262
"""Wrap a value with version information."""
63-
# If already versioned, don't double-wrap
64-
if _VERSION_KEY in value:
63+
# If already properly versioned, don't double-wrap
64+
if _VERSION_KEY in value and _VERSIONED_DATA_KEY in value:
6565
return value
6666

6767
return {_VERSION_KEY: self.version, _VERSIONED_DATA_KEY: value}
@@ -80,8 +80,11 @@ def _unwrap_value(self, value: dict[str, Any] | None) -> dict[str, Any] | None:
8080
# Version mismatch - auto-invalidate by returning None
8181
return None
8282

83-
# Extract the actual data
84-
return value.get(_VERSIONED_DATA_KEY, value)
83+
# Extract the actual data (must be present in properly wrapped data)
84+
if _VERSIONED_DATA_KEY not in value:
85+
# Malformed versioned data - treat as corruption
86+
return None
87+
return value[_VERSIONED_DATA_KEY]
8588

8689
@override
8790
def get(self, key: str, *, collection: str | None = None) -> dict[str, Any] | None:
@@ -103,7 +106,8 @@ def ttl(self, key: str, *, collection: str | None = None) -> tuple[dict[str, Any
103106
@override
104107
def ttl_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[tuple[dict[str, Any] | None, float | None]]:
105108
results = self.key_value.ttl_many(keys=keys, collection=collection)
106-
return [(self._unwrap_value(value), ttl if self._unwrap_value(value) is not None else None) for (value, ttl) in results]
109+
unwrapped = [(self._unwrap_value(value), ttl) for (value, ttl) in results]
110+
return [(value, ttl if value is not None else None) for (value, ttl) in unwrapped]
107111

108112
@override
109113
def put(self, key: str, value: Mapping[str, Any], *, collection: str | None = None, ttl: SupportsFloat | None = None) -> None:

0 commit comments

Comments
 (0)