key-value/key-value-aio/src/key_value/aio/wrappers/bulkhead/__init__.py (3 additions, 0 deletions)
@@ -0,0 +1,3 @@
from key_value.aio.wrappers.bulkhead.wrapper import BulkheadWrapper

__all__ = ["BulkheadWrapper"]
key-value/key-value-aio/src/key_value/aio/wrappers/bulkhead/wrapper.py (135 additions, 0 deletions)
@@ -0,0 +1,135 @@
import asyncio
from collections.abc import Callable, Coroutine, Mapping, Sequence
from typing import Any, SupportsFloat, TypeVar

from key_value.shared.errors.wrappers.bulkhead import BulkheadFullError
from typing_extensions import override

from key_value.aio.protocols.key_value import AsyncKeyValue
from key_value.aio.wrappers.base import BaseWrapper

T = TypeVar("T")


class BulkheadWrapper(BaseWrapper):
    """Wrapper that implements the bulkhead pattern to isolate operations with resource pools.

    This wrapper limits the number of concurrent operations and queued operations to prevent
    resource exhaustion and isolate failures. The bulkhead pattern is inspired by ship bulkheads
    that prevent a single hull breach from sinking the entire ship.

    Benefits:
    - Prevents a single slow or failing backend from consuming all resources
    - Limits concurrent requests to protect the backend from overload
    - Provides a bounded queue to prevent unbounded memory growth
    - Enables graceful degradation under high load

    Example:
        bulkhead = BulkheadWrapper(
            key_value=store,
            max_concurrent=10,  # Max 10 concurrent operations
            max_waiting=20,  # Max 20 operations can wait in queue
        )

        try:
            await bulkhead.get(key="mykey")
        except BulkheadFullError:
            # Too many concurrent operations, system is overloaded
            # Handle gracefully (return cached value, error response, etc.)
            pass
    """

    def __init__(
        self,
        key_value: AsyncKeyValue,
        max_concurrent: int = 10,
        max_waiting: int = 20,
    ) -> None:
        """Initialize the bulkhead wrapper.

        Args:
            key_value: The store to wrap.
            max_concurrent: Maximum number of concurrent operations. Defaults to 10.
            max_waiting: Maximum number of operations that can wait in queue. Defaults to 20.
        """
        self.key_value: AsyncKeyValue = key_value
        self.max_concurrent: int = max_concurrent
        self.max_waiting: int = max_waiting

        # Use semaphore to limit concurrent operations
        self._semaphore: asyncio.Semaphore = asyncio.Semaphore(max_concurrent)
        self._waiting_count: int = 0
        self._waiting_lock: asyncio.Lock = asyncio.Lock()

        super().__init__()

    async def _execute_with_bulkhead(self, operation: Callable[..., Coroutine[Any, Any, T]], *args: Any, **kwargs: Any) -> T:
        """Execute an operation with bulkhead resource limiting."""
        # Check if we're over capacity before even trying
        # Count the number currently executing + waiting
        async with self._waiting_lock:
            # _semaphore._value tells us how many slots are available
            # max_concurrent - _value = number currently executing
            currently_executing = self.max_concurrent - self._semaphore._value
Review comment from a contributor:
⚠️ Potential issue | 🟠 Major

Avoid accessing private semaphore internals.

Accessing self._semaphore._value relies on private implementation details of asyncio.Semaphore that may change across Python versions. Consider using locked() or tracking concurrent operations with your own counter instead.

For example, track concurrent operations explicitly:

# In __init__:
self._executing_count: int = 0

# In _execute_with_bulkhead:
async with self._waiting_lock:
    currently_executing = self._executing_count
    total_in_system = currently_executing + self._waiting_count
    
    if total_in_system >= self.max_concurrent + self.max_waiting:
        raise BulkheadFullError(...)
    
    self._waiting_count += 1
    registered_waiting = True

try:
    async with self._semaphore:
        async with self._waiting_lock:
            self._waiting_count -= 1
            self._executing_count += 1
            registered_waiting = False
        
        try:
            return await operation(*args, **kwargs)
        finally:
            async with self._waiting_lock:
                self._executing_count -= 1
except BaseException:
    if registered_waiting:
        async with self._waiting_lock:
            self._waiting_count -= 1
    raise
🤖 Prompt for AI Agents
In key-value/key-value-aio/src/key_value/aio/wrappers/bulkhead/wrapper.py around
line 73, the code reads self._semaphore._value which accesses private Semaphore
internals; replace this by tracking concurrent and waiting counts explicitly
under the existing _waiting_lock: add an _executing_count attribute initialized
in __init__, use _waiting_lock to read currently_executing and total_in_system
to enforce limits, increment _waiting_count when registering, then when entering
the semaphore decrement _waiting_count and increment _executing_count, and in a
finally block decrement _executing_count; also ensure any exception path that
aborts before acquiring the semaphore decrements _waiting_count under the lock
so counters remain consistent.

            total_in_system = currently_executing + self._waiting_count

            if total_in_system >= self.max_concurrent + self.max_waiting:
                raise BulkheadFullError(max_concurrent=self.max_concurrent, max_waiting=self.max_waiting)

            # We're allowed in - increment waiting count
            self._waiting_count += 1

        try:
            # Acquire semaphore (may block)
            async with self._semaphore:
                # Once we have the semaphore, we're executing (not waiting)
                async with self._waiting_lock:
                    self._waiting_count -= 1

                # Execute the operation
                return await operation(*args, **kwargs)
        except BaseException:
            # Make sure to clean up waiting count if we fail before executing
            async with self._waiting_lock:
                if self._waiting_count > 0:
                    self._waiting_count -= 1
            raise

    @override
    async def get(self, key: str, *, collection: str | None = None) -> dict[str, Any] | None:
        return await self._execute_with_bulkhead(self.key_value.get, key=key, collection=collection)

    @override
    async def get_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[dict[str, Any] | None]:
        return await self._execute_with_bulkhead(self.key_value.get_many, keys=keys, collection=collection)

    @override
    async def ttl(self, key: str, *, collection: str | None = None) -> tuple[dict[str, Any] | None, float | None]:
        return await self._execute_with_bulkhead(self.key_value.ttl, key=key, collection=collection)

    @override
    async def ttl_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[tuple[dict[str, Any] | None, float | None]]:
        return await self._execute_with_bulkhead(self.key_value.ttl_many, keys=keys, collection=collection)

    @override
    async def put(self, key: str, value: Mapping[str, Any], *, collection: str | None = None, ttl: SupportsFloat | None = None) -> None:
        return await self._execute_with_bulkhead(self.key_value.put, key=key, value=value, collection=collection, ttl=ttl)

    @override
    async def put_many(
        self,
        keys: Sequence[str],
        values: Sequence[Mapping[str, Any]],
        *,
        collection: str | None = None,
        ttl: SupportsFloat | None = None,
    ) -> None:
        return await self._execute_with_bulkhead(self.key_value.put_many, keys=keys, values=values, collection=collection, ttl=ttl)

    @override
    async def delete(self, key: str, *, collection: str | None = None) -> bool:
        return await self._execute_with_bulkhead(self.key_value.delete, key=key, collection=collection)

    @override
    async def delete_many(self, keys: Sequence[str], *, collection: str | None = None) -> int:
        return await self._execute_with_bulkhead(self.key_value.delete_many, keys=keys, collection=collection)
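
Side note, not part of this diff: a minimal usage sketch of the wrapper above, handy for eyeballing the admission limits. The SlowStub backend, the main() driver, and the specific limits are illustrative assumptions; only get() is stubbed, and the ok/rejected split is approximate since it depends on task scheduling.

import asyncio

from key_value.aio.wrappers.bulkhead.wrapper import BulkheadWrapper
from key_value.shared.errors.wrappers.bulkhead import BulkheadFullError


class SlowStub:
    """Stand-in backend (assumption, not a real store) that only implements get(), slowly."""

    async def get(self, key: str, *, collection: str | None = None) -> dict | None:
        await asyncio.sleep(0.5)  # simulate a slow backend
        return {"key": key}


async def main() -> None:
    # 2 concurrent + 3 waiting = at most 5 requests admitted at once
    bulkhead = BulkheadWrapper(key_value=SlowStub(), max_concurrent=2, max_waiting=3)

    async def one_call(i: int) -> str:
        try:
            await bulkhead.get(key=f"k{i}")
            return "ok"
        except BulkheadFullError:
            return "rejected"

    results = await asyncio.gather(*(one_call(i) for i in range(10)))
    # With 10 simultaneous calls, expect roughly 5 "ok" and 5 "rejected"
    print(results)


asyncio.run(main())

With max_concurrent=2 and max_waiting=3, at most five of the ten simultaneous calls should be admitted; the rest fail fast with BulkheadFullError rather than piling up behind the slow backend.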
key-value/key-value-aio/src/key_value/aio/wrappers/circuit_breaker/__init__.py (3 additions, 0 deletions)
@@ -0,0 +1,3 @@
from key_value.aio.wrappers.circuit_breaker.wrapper import CircuitBreakerWrapper

__all__ = ["CircuitBreakerWrapper"]
key-value/key-value-aio/src/key_value/aio/wrappers/circuit_breaker/wrapper.py (174 additions, 0 deletions)
@@ -0,0 +1,174 @@
import time
from collections.abc import Callable, Coroutine, Mapping, Sequence
from enum import Enum
from typing import Any, SupportsFloat, TypeVar

from key_value.shared.errors.wrappers.circuit_breaker import CircuitOpenError
from typing_extensions import override

from key_value.aio.protocols.key_value import AsyncKeyValue
from key_value.aio.wrappers.base import BaseWrapper

T = TypeVar("T")


class CircuitState(Enum):
    """States for the circuit breaker."""

    CLOSED = "closed"  # Normal operation
    OPEN = "open"  # Failing, blocking requests
    HALF_OPEN = "half_open"  # Testing if service recovered


class CircuitBreakerWrapper(BaseWrapper):
    """Wrapper that implements the circuit breaker pattern to prevent cascading failures.

    This wrapper tracks operation failures and opens the circuit after a threshold of consecutive
    failures. When the circuit is open, requests are blocked immediately without attempting the
    operation. After a recovery timeout, the circuit moves to half-open state to test if the
    backend has recovered.

    The circuit breaker pattern is essential for production resilience as it:
    - Prevents cascading failures when a backend becomes unhealthy
    - Reduces load on failing backends, giving them time to recover
    - Provides fast failure responses instead of waiting for timeouts
    - Automatically attempts recovery after a configured timeout

    Example:
        circuit_breaker = CircuitBreakerWrapper(
            key_value=store,
            failure_threshold=5,  # Open after 5 consecutive failures
            recovery_timeout=30.0,  # Try recovery after 30 seconds
            success_threshold=2,  # Close after 2 successes in half-open
        )

        try:
            value = await circuit_breaker.get(key="mykey")
        except CircuitOpenError:
            # Circuit is open, backend is considered unhealthy
            # Handle gracefully (use cache, return default, etc.)
            pass
    """

    def __init__(
        self,
        key_value: AsyncKeyValue,
        failure_threshold: int = 5,
        recovery_timeout: float = 30.0,
        success_threshold: int = 2,
        error_types: tuple[type[Exception], ...] = (Exception,),
    ) -> None:
        """Initialize the circuit breaker wrapper.

        Args:
            key_value: The store to wrap.
            failure_threshold: Number of consecutive failures before opening the circuit. Defaults to 5.
            recovery_timeout: Seconds to wait before attempting recovery (moving to half-open). Defaults to 30.0.
            success_threshold: Number of consecutive successes in half-open state before closing the circuit. Defaults to 2.
            error_types: Tuple of exception types that count as failures. Defaults to (Exception,).
        """
        self.key_value: AsyncKeyValue = key_value
        self.failure_threshold: int = failure_threshold
        self.recovery_timeout: float = recovery_timeout
        self.success_threshold: int = success_threshold
        self.error_types: tuple[type[Exception], ...] = error_types

        # Circuit state
        self._state: CircuitState = CircuitState.CLOSED
        self._failure_count: int = 0
        self._success_count: int = 0
        self._last_failure_time: float | None = None  # Wall clock time for diagnostics
        self._last_failure_tick: float | None = None  # Monotonic time for timeout calculations

        super().__init__()

    def _check_circuit(self) -> None:
        """Check the circuit state and potentially transition states."""
        if self._state == CircuitState.OPEN:
            # Check if we should move to half-open (using monotonic time for reliability)
            if self._last_failure_tick is not None and time.monotonic() - self._last_failure_tick >= self.recovery_timeout:
                self._state = CircuitState.HALF_OPEN
                self._success_count = 0
            else:
                # Circuit is still open, raise error
                raise CircuitOpenError(failure_count=self._failure_count, last_failure_time=self._last_failure_time)

    def _on_success(self) -> None:
        """Handle successful operation."""
        if self._state == CircuitState.HALF_OPEN:
            self._success_count += 1
            if self._success_count >= self.success_threshold:
                # Close the circuit
                self._state = CircuitState.CLOSED
                self._failure_count = 0
                self._success_count = 0
        elif self._state == CircuitState.CLOSED:
            # Reset failure count on success
            self._failure_count = 0

    def _on_failure(self) -> None:
        """Handle failed operation."""
        self._last_failure_time = time.time()  # Wall clock for diagnostics
        self._last_failure_tick = time.monotonic()  # Monotonic time for timeout calculations

        if self._state == CircuitState.HALF_OPEN:
            # Failed in half-open, go back to open
            self._state = CircuitState.OPEN
            self._success_count = 0
        elif self._state == CircuitState.CLOSED:
            self._failure_count += 1
            if self._failure_count >= self.failure_threshold:
                # Open the circuit
                self._state = CircuitState.OPEN

    async def _execute_with_circuit_breaker(self, operation: Callable[..., Coroutine[Any, Any, T]], *args: Any, **kwargs: Any) -> T:
        """Execute an operation with circuit breaker logic."""
        self._check_circuit()

        try:
            result = await operation(*args, **kwargs)
        except self.error_types:
            self._on_failure()
            raise
        else:
            self._on_success()
            return result

    @override
    async def get(self, key: str, *, collection: str | None = None) -> dict[str, Any] | None:
        return await self._execute_with_circuit_breaker(self.key_value.get, key=key, collection=collection)

    @override
    async def get_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[dict[str, Any] | None]:
        return await self._execute_with_circuit_breaker(self.key_value.get_many, keys=keys, collection=collection)

    @override
    async def ttl(self, key: str, *, collection: str | None = None) -> tuple[dict[str, Any] | None, float | None]:
        return await self._execute_with_circuit_breaker(self.key_value.ttl, key=key, collection=collection)

    @override
    async def ttl_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[tuple[dict[str, Any] | None, float | None]]:
        return await self._execute_with_circuit_breaker(self.key_value.ttl_many, keys=keys, collection=collection)

    @override
    async def put(self, key: str, value: Mapping[str, Any], *, collection: str | None = None, ttl: SupportsFloat | None = None) -> None:
        return await self._execute_with_circuit_breaker(self.key_value.put, key=key, value=value, collection=collection, ttl=ttl)

    @override
    async def put_many(
        self,
        keys: Sequence[str],
        values: Sequence[Mapping[str, Any]],
        *,
        collection: str | None = None,
        ttl: SupportsFloat | None = None,
    ) -> None:
        return await self._execute_with_circuit_breaker(self.key_value.put_many, keys=keys, values=values, collection=collection, ttl=ttl)

    @override
    async def delete(self, key: str, *, collection: str | None = None) -> bool:
        return await self._execute_with_circuit_breaker(self.key_value.delete, key=key, collection=collection)

    @override
    async def delete_many(self, keys: Sequence[str], *, collection: str | None = None) -> int:
        return await self._execute_with_circuit_breaker(self.key_value.delete_many, keys=keys, collection=collection)
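
Likewise, not part of this diff: a small sketch that walks the state machine above from CLOSED to OPEN and back through HALF_OPEN. The FlakyStub backend, the main() driver, and the short timeouts are illustrative assumptions chosen so the sketch runs quickly.

import asyncio

from key_value.aio.wrappers.circuit_breaker.wrapper import CircuitBreakerWrapper
from key_value.shared.errors.wrappers.circuit_breaker import CircuitOpenError


class FlakyStub:
    """Stand-in backend (assumption, not a real store) whose get() fails until told to recover."""

    def __init__(self) -> None:
        self.healthy = False

    async def get(self, key: str, *, collection: str | None = None) -> dict | None:
        if not self.healthy:
            raise ConnectionError("backend unavailable")
        return {"key": key}


async def main() -> None:
    backend = FlakyStub()
    breaker = CircuitBreakerWrapper(
        key_value=backend,
        failure_threshold=3,   # open after 3 consecutive failures
        recovery_timeout=0.1,  # short timeout so the sketch runs quickly
        success_threshold=1,   # a single half-open success closes the circuit
    )

    # Three failures open the circuit.
    for _ in range(3):
        try:
            await breaker.get(key="k")
        except ConnectionError:
            pass

    # While open, calls fail fast without touching the backend.
    try:
        await breaker.get(key="k")
    except CircuitOpenError:
        print("circuit open, failing fast")

    # After the recovery timeout the next call is let through (half-open);
    # a success closes the circuit again.
    backend.healthy = True
    await asyncio.sleep(0.2)
    print(await breaker.get(key="k"))  # {'key': 'k'}


asyncio.run(main())

A single success closes the circuit here only because success_threshold=1; with the default of 2, a second successful half-open call would be required before returning to CLOSED.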
key-value/key-value-aio/src/key_value/aio/wrappers/rate_limit/__init__.py (3 additions, 0 deletions)
@@ -0,0 +1,3 @@
from key_value.aio.wrappers.rate_limit.wrapper import RateLimitWrapper

__all__ = ["RateLimitWrapper"]