diff --git a/README.md b/README.md index 2fe2db4e..152fbbd6 100644 --- a/README.md +++ b/README.md @@ -245,6 +245,7 @@ the protocol, your application code might be simplified by using an adapter. | Adapter | Description | Example | |---------|:------------|:------------------| +| DataclassAdapter | Type-safe storage/retrieval of dataclass models with transparent serialization/deserialization. | `DataclassAdapter(key_value=memory_store, dataclass_type=User)` | | PydanticAdapter | Type-safe storage/retrieval of Pydantic models with transparent serialization/deserialization. | `PydanticAdapter(key_value=memory_store, pydantic_model=User)` | | RaiseOnMissingAdapter | Optional raise-on-missing behavior for `get`, `get_many`, `ttl`, and `ttl_many`. | `RaiseOnMissingAdapter(key_value=memory_store)` | diff --git a/key-value/key-value-aio/src/key_value/aio/adapters/base/__init__.py b/key-value/key-value-aio/src/key_value/aio/adapters/base/__init__.py new file mode 100644 index 00000000..56b1edb1 --- /dev/null +++ b/key-value/key-value-aio/src/key_value/aio/adapters/base/__init__.py @@ -0,0 +1,3 @@ +from key_value.aio.adapters.pydantic.base import BasePydanticAdapter + +__all__ = ["BasePydanticAdapter"] diff --git a/key-value/key-value-aio/src/key_value/aio/adapters/dataclass/__init__.py b/key-value/key-value-aio/src/key_value/aio/adapters/dataclass/__init__.py new file mode 100644 index 00000000..e45b4ea4 --- /dev/null +++ b/key-value/key-value-aio/src/key_value/aio/adapters/dataclass/__init__.py @@ -0,0 +1,3 @@ +from key_value.aio.adapters.dataclass.adapter import DataclassAdapter + +__all__ = ["DataclassAdapter"] diff --git a/key-value/key-value-aio/src/key_value/aio/adapters/dataclass/adapter.py b/key-value/key-value-aio/src/key_value/aio/adapters/dataclass/adapter.py new file mode 100644 index 00000000..bbd1af01 --- /dev/null +++ b/key-value/key-value-aio/src/key_value/aio/adapters/dataclass/adapter.py @@ -0,0 +1,71 @@ +from collections.abc import Sequence +from dataclasses import is_dataclass +from typing import Any, TypeVar, get_args, get_origin + +from key_value.shared.type_checking.bear_spray import bear_spray +from pydantic.type_adapter import TypeAdapter + +from key_value.aio.adapters.pydantic.base import BasePydanticAdapter +from key_value.aio.protocols.key_value import AsyncKeyValue + +T = TypeVar("T") + + +class DataclassAdapter(BasePydanticAdapter[T]): + """Adapter around a KVStore-compliant Store that allows type-safe persistence of dataclasses. + + This adapter works with both standard library dataclasses and Pydantic dataclasses, + leveraging Pydantic's TypeAdapter for robust validation and serialization. + """ + + _inner_type: type[Any] + + # Beartype cannot handle the parameterized type annotation (type[T]) used here for this generic dataclass adapter. + # Using @bear_spray to bypass beartype's runtime checks for this specific method. + @bear_spray + def __init__( + self, + key_value: AsyncKeyValue, + dataclass_type: type[T], + default_collection: str | None = None, + raise_on_validation_error: bool = False, + ) -> None: + """Create a new DataclassAdapter. + + Args: + key_value: The AsyncKeyValue to use. + dataclass_type: The dataclass type to use. Can be a single dataclass or list[dataclass]. + default_collection: The default collection to use. + raise_on_validation_error: Whether to raise a DeserializationError if validation fails during reads. Otherwise, + calls will return None if validation fails. + + Raises: + TypeError: If dataclass_type is not a dataclass type. 
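Hedged usage sketch for the new adapter advertised in the README row above: it mirrors the test fixtures added later in this PR (MemoryStore, a simple `User` dataclass) and is illustrative only, not part of the diff.

```python
import asyncio
from dataclasses import dataclass

from key_value.aio.adapters.dataclass import DataclassAdapter
from key_value.aio.stores.memory.store import MemoryStore


@dataclass
class User:
    name: str
    age: int
    email: str


async def main() -> None:
    store = MemoryStore()
    # Values go in as User instances and come back out as User instances.
    adapter = DataclassAdapter[User](key_value=store, dataclass_type=User, default_collection="users")

    await adapter.put(key="user:1", value=User(name="John Doe", age=30, email="john.doe@example.com"))
    assert await adapter.get(key="user:1") == User(name="John Doe", age=30, email="john.doe@example.com")

    # delete() reports whether an entry was actually removed.
    assert await adapter.delete(key="user:1")


asyncio.run(main())
```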
+ """ + self._key_value = key_value + + origin = get_origin(dataclass_type) + self._is_list_model = origin is not None and isinstance(origin, type) and issubclass(origin, Sequence) + + # Extract the inner type for list models + if self._is_list_model: + args = get_args(dataclass_type) + if not args: + msg = f"List type {dataclass_type} must have a type argument" + raise TypeError(msg) + self._inner_type = args[0] + else: + self._inner_type = dataclass_type + + # Validate that the inner type is a dataclass + if not is_dataclass(self._inner_type): + msg = f"{self._inner_type} is not a dataclass" + raise TypeError(msg) + + self._type_adapter = TypeAdapter[T](dataclass_type) + self._default_collection = default_collection + self._raise_on_validation_error = raise_on_validation_error + + def _get_model_type_name(self) -> str: + """Return the model type name for error messages.""" + return "dataclass" diff --git a/key-value/key-value-aio/src/key_value/aio/adapters/pydantic/adapter.py b/key-value/key-value-aio/src/key_value/aio/adapters/pydantic/adapter.py index 8697f2c3..ec1afdd9 100644 --- a/key-value/key-value-aio/src/key_value/aio/adapters/pydantic/adapter.py +++ b/key-value/key-value-aio/src/key_value/aio/adapters/pydantic/adapter.py @@ -1,28 +1,21 @@ from collections.abc import Sequence -from typing import Any, Generic, SupportsFloat, TypeVar, get_origin, overload +from typing import TypeVar, get_origin -from key_value.shared.errors import DeserializationError, SerializationError from key_value.shared.type_checking.bear_spray import bear_spray -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel from pydantic.type_adapter import TypeAdapter -from pydantic_core import PydanticSerializationError +from key_value.aio.adapters.pydantic.base import BasePydanticAdapter from key_value.aio.protocols.key_value import AsyncKeyValue T = TypeVar("T", bound=BaseModel | Sequence[BaseModel]) -class PydanticAdapter(Generic[T]): +class PydanticAdapter(BasePydanticAdapter[T]): """Adapter around a KVStore-compliant Store that allows type-safe persistence of Pydantic models.""" - _key_value: AsyncKeyValue - _is_list_model: bool - _type_adapter: TypeAdapter[T] - _default_collection: str | None - _raise_on_validation_error: bool - - # Beartype doesn't like our `type[T] includes a bound on Sequence[...] as the subscript is not checkable at runtime - # For just the next 20 or so lines we are no longer bear bros but have no fear, we will be back soon! + # Beartype cannot handle the parameterized type annotation (type[T]) used here for this generic adapter. + # Using @bear_spray to bypass beartype's runtime checks for this specific method. @bear_spray def __init__( self, @@ -35,212 +28,28 @@ def __init__( Args: key_value: The KVStore to use. - pydantic_model: The Pydantic model to use. + pydantic_model: The Pydantic model to use. Can be a single Pydantic model or list[Pydantic model]. default_collection: The default collection to use. - raise_on_validation_error: Whether to raise a ValidationError if the model is invalid. - """ + raise_on_validation_error: Whether to raise a DeserializationError if validation fails during reads. Otherwise, + calls will return None if validation fails. + Raises: + TypeError: If pydantic_model is a sequence type other than list (e.g., tuple is not supported). 
+ """ self._key_value = key_value origin = get_origin(pydantic_model) - self._is_list_model = origin is not None and isinstance(origin, type) and issubclass(origin, Sequence) + self._is_list_model = origin is list + + # Validate that if it's a generic type, it must be a list (not tuple, etc.) + if origin is not None and origin is not list: + msg = f"Only list[BaseModel] is supported for sequence types, got {pydantic_model}" + raise TypeError(msg) self._type_adapter = TypeAdapter[T](pydantic_model) self._default_collection = default_collection self._raise_on_validation_error = raise_on_validation_error - def _validate_model(self, value: dict[str, Any]) -> T | None: - """Validate and deserialize a dict into the configured Pydantic model. - - This method handles both single models and list models. For list models, it expects the value - to contain an "items" key with the list data, following the convention used by `_serialize_model`. - If validation fails and `raise_on_validation_error` is False, returns None instead of raising. - - Args: - value: The dict to validate and convert to a Pydantic model. - - Returns: - The validated model instance, or None if validation fails and errors are suppressed. - - Raises: - DeserializationError: If validation fails and `raise_on_validation_error` is True. - """ - try: - if self._is_list_model: - return self._type_adapter.validate_python(value.get("items", [])) - - return self._type_adapter.validate_python(value) - except ValidationError as e: - if self._raise_on_validation_error: - msg = f"Invalid Pydantic model: {value}" - raise DeserializationError(msg) from e - return None - - def _serialize_model(self, value: T) -> dict[str, Any]: - """Serialize a Pydantic model to a dict for storage. - - This method handles both single models and list models. For list models, it wraps the serialized - list in a dict with an "items" key (e.g., {"items": [...]}) to ensure consistent dict-based storage - format across all value types. This wrapping convention is expected by `_validate_model` during - deserialization. - - Args: - value: The Pydantic model instance to serialize. - - Returns: - A dict representation of the model suitable for storage. - - Raises: - SerializationError: If the model cannot be serialized. - """ - try: - if self._is_list_model: - return {"items": self._type_adapter.dump_python(value, mode="json")} - - return self._type_adapter.dump_python(value, mode="json") # pyright: ignore[reportAny] - except PydanticSerializationError as e: - msg = f"Invalid Pydantic model: {e}" - raise SerializationError(msg) from e - - @overload - async def get(self, key: str, *, collection: str | None = None, default: T) -> T: ... - - @overload - async def get(self, key: str, *, collection: str | None = None, default: None = None) -> T | None: ... - - async def get(self, key: str, *, collection: str | None = None, default: T | None = None) -> T | None: - """Get and validate a model by key. - - Args: - key: The key to retrieve. - collection: The collection to use. If not provided, uses the default collection. - default: The default value to return if the key doesn't exist or validation fails. - - Returns: - The parsed model instance if found and valid, or the default value if key doesn't exist or validation fails. - - Raises: - DeserializationError if the stored data cannot be validated as the model and the PydanticAdapter is configured to - raise on validation error. 
- - Note: - When raise_on_validation_error=False and validation fails, returns the default value (which may be None). - When raise_on_validation_error=True and validation fails, raises DeserializationError. - """ - collection = collection or self._default_collection - - if value := await self._key_value.get(key=key, collection=collection): - validated = self._validate_model(value=value) - if validated is not None: - return validated - - return default - - @overload - async def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: T) -> list[T]: ... - - @overload - async def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: None = None) -> list[T | None]: ... - - async def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: T | None = None) -> list[T] | list[T | None]: - """Batch get and validate models by keys, preserving order. - - Args: - keys: The list of keys to retrieve. - collection: The collection to use. If not provided, uses the default collection. - default: The default value to return for keys that don't exist or fail validation. - - Returns: - A list of parsed model instances, with default values for missing keys or validation failures. - - Raises: - DeserializationError if the stored data cannot be validated as the model and the PydanticAdapter is configured to - raise on validation error. - - Note: - When raise_on_validation_error=False and validation fails for any key, that position in the returned list - will contain the default value (which may be None). The method returns a complete list matching the order - and length of the input keys, with defaults substituted for missing or invalid entries. - """ - collection = collection or self._default_collection - - values: list[dict[str, Any] | None] = await self._key_value.get_many(keys=keys, collection=collection) - - result: list[T | None] = [] - for value in values: - if value is None: - result.append(default) - else: - validated = self._validate_model(value=value) - result.append(validated if validated is not None else default) - return result - - async def put(self, key: str, value: T, *, collection: str | None = None, ttl: SupportsFloat | None = None) -> None: - """Serialize and store a model. - - Propagates SerializationError if the model cannot be serialized. - """ - collection = collection or self._default_collection - - value_dict: dict[str, Any] = self._serialize_model(value=value) - - await self._key_value.put(key=key, value=value_dict, collection=collection, ttl=ttl) - - async def put_many( - self, keys: Sequence[str], values: Sequence[T], *, collection: str | None = None, ttl: SupportsFloat | None = None - ) -> None: - """Serialize and store multiple models, preserving order alignment with keys.""" - collection = collection or self._default_collection - - value_dicts: list[dict[str, Any]] = [self._serialize_model(value=value) for value in values] - - await self._key_value.put_many(keys=keys, values=value_dicts, collection=collection, ttl=ttl) - - async def delete(self, key: str, *, collection: str | None = None) -> bool: - """Delete a model by key. Returns True if a value was deleted, else False.""" - collection = collection or self._default_collection - - return await self._key_value.delete(key=key, collection=collection) - - async def delete_many(self, keys: Sequence[str], *, collection: str | None = None) -> int: - """Delete multiple models by key. 
Returns the count of deleted entries.""" - collection = collection or self._default_collection - - return await self._key_value.delete_many(keys=keys, collection=collection) - - async def ttl(self, key: str, *, collection: str | None = None) -> tuple[T | None, float | None]: - """Get a model and its TTL seconds if present. - - Args: - key: The key to retrieve. - collection: The collection to use. If not provided, uses the default collection. - - Returns: - A tuple of (model, ttl_seconds). Returns (None, None) if the key is missing or validation fails. - - Note: - When validation fails and raise_on_validation_error=False, returns (None, None) even if TTL data exists. - When validation fails and raise_on_validation_error=True, raises DeserializationError. - """ - collection = collection or self._default_collection - - entry: dict[str, Any] | None - ttl_info: float | None - - entry, ttl_info = await self._key_value.ttl(key=key, collection=collection) - - if entry is None: - return (None, None) - - if validated_model := self._validate_model(value=entry): - return (validated_model, ttl_info) - - return (None, None) - - async def ttl_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[tuple[T | None, float | None]]: - """Batch get models with TTLs. Each element is (model|None, ttl_seconds|None).""" - collection = collection or self._default_collection - - entries: list[tuple[dict[str, Any] | None, float | None]] = await self._key_value.ttl_many(keys=keys, collection=collection) - - return [(self._validate_model(value=entry) if entry else None, ttl_info) for entry, ttl_info in entries] + def _get_model_type_name(self) -> str: + """Return the model type name for error messages.""" + return "Pydantic model" diff --git a/key-value/key-value-aio/src/key_value/aio/adapters/pydantic/base.py b/key-value/key-value-aio/src/key_value/aio/adapters/pydantic/base.py new file mode 100644 index 00000000..509a6866 --- /dev/null +++ b/key-value/key-value-aio/src/key_value/aio/adapters/pydantic/base.py @@ -0,0 +1,240 @@ +from abc import ABC, abstractmethod +from collections.abc import Sequence +from typing import Any, Generic, SupportsFloat, TypeVar, overload + +from key_value.shared.errors import DeserializationError, SerializationError +from pydantic import ValidationError +from pydantic.type_adapter import TypeAdapter +from pydantic_core import PydanticSerializationError + +from key_value.aio.protocols.key_value import AsyncKeyValue + +T = TypeVar("T") + + +class BasePydanticAdapter(Generic[T], ABC): + """Base adapter using Pydantic TypeAdapter for validation and serialization. + + This abstract base class provides shared functionality for adapters that use + Pydantic's TypeAdapter for validation and serialization. Concrete subclasses + must implement _get_model_type_name() to provide appropriate error messages. + """ + + _key_value: AsyncKeyValue + _is_list_model: bool + _type_adapter: TypeAdapter[T] + _default_collection: str | None + _raise_on_validation_error: bool + + @abstractmethod + def _get_model_type_name(self) -> str: + """Return the model type name for error messages. + + Returns: + A string describing the model type (e.g., "Pydantic model", "dataclass"). + """ + ... + + def _validate_model(self, value: dict[str, Any]) -> T | None: + """Validate and deserialize a dict into the configured model type. + + This method handles both single models and list models. 
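To illustrate how the new base class is intended to be extended, here is a hypothetical third adapter built on it. The `TypedDictAdapter` name and TypedDict support are assumptions for illustration only; this PR only ships the Pydantic and dataclass subclasses.

```python
from typing import TypeVar

from pydantic.type_adapter import TypeAdapter

from key_value.aio.adapters.pydantic.base import BasePydanticAdapter
from key_value.aio.protocols.key_value import AsyncKeyValue

T = TypeVar("T")


class TypedDictAdapter(BasePydanticAdapter[T]):
    """Hypothetical adapter for TypedDict payloads, reusing the shared base logic (not part of this PR)."""

    def __init__(
        self,
        key_value: AsyncKeyValue,
        typed_dict_type: type[T],
        default_collection: str | None = None,
        raise_on_validation_error: bool = False,
    ) -> None:
        # Single-model only in this sketch; the real adapters also handle list[...] wrapping.
        self._key_value = key_value
        self._is_list_model = False
        self._type_adapter = TypeAdapter[T](typed_dict_type)
        self._default_collection = default_collection
        self._raise_on_validation_error = raise_on_validation_error

    def _get_model_type_name(self) -> str:
        # The only hook the base class requires: a label used in error messages.
        return "TypedDict"
```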
For list models, it expects the value + to contain an "items" key with the list data, following the convention used by `_serialize_model`. + If validation fails and `raise_on_validation_error` is False, returns None instead of raising. + + Args: + value: The dict to validate and convert to a model. + + Returns: + The validated model instance, or None if validation fails and errors are suppressed. + + Raises: + DeserializationError: If validation fails and `raise_on_validation_error` is True. + """ + try: + if self._is_list_model: + if "items" not in value: + if self._raise_on_validation_error: + msg = f"Invalid {self._get_model_type_name()} payload: missing 'items' wrapper" + raise DeserializationError(msg) + return None + return self._type_adapter.validate_python(value["items"]) + + return self._type_adapter.validate_python(value) + except ValidationError as e: + if self._raise_on_validation_error: + details = e.errors(include_input=False) + msg = f"Invalid {self._get_model_type_name()}: {details}" + raise DeserializationError(msg) from e + return None + + def _serialize_model(self, value: T) -> dict[str, Any]: + """Serialize a model to a dict for storage. + + This method handles both single models and list models. For list models, it wraps the serialized + list in a dict with an "items" key (e.g., {"items": [...]}) to ensure consistent dict-based storage + format across all value types. This wrapping convention is expected by `_validate_model` during + deserialization. + + Args: + value: The model instance to serialize. + + Returns: + A dict representation of the model suitable for storage. + + Raises: + SerializationError: If the model cannot be serialized. + """ + try: + if self._is_list_model: + return {"items": self._type_adapter.dump_python(value, mode="json")} + + return self._type_adapter.dump_python(value, mode="json") # pyright: ignore[reportAny] + except PydanticSerializationError as e: + msg = f"Invalid {self._get_model_type_name()}: {e}" + raise SerializationError(msg) from e + + @overload + async def get(self, key: str, *, collection: str | None = None, default: T) -> T: ... + + @overload + async def get(self, key: str, *, collection: str | None = None, default: None = None) -> T | None: ... + + async def get(self, key: str, *, collection: str | None = None, default: T | None = None) -> T | None: + """Get and validate a model by key. + + Args: + key: The key to retrieve. + collection: The collection to use. If not provided, uses the default collection. + default: The default value to return if the key doesn't exist or validation fails. + + Returns: + The parsed model instance if found and valid, or the default value if key doesn't exist or validation fails. + + Raises: + DeserializationError: If the stored data cannot be validated as the model and the adapter is configured to + raise on validation error. + + Note: + When raise_on_validation_error=False and validation fails, returns the default value (which may be None). + When raise_on_validation_error=True and validation fails, raises DeserializationError. + """ + collection = collection or self._default_collection + + value = await self._key_value.get(key=key, collection=collection) + if value is not None: + validated = self._validate_model(value=value) + if validated is not None: + return validated + + return default + + @overload + async def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: T) -> list[T]: ... 
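The "items" wrapping described above is visible directly in the backing store; a sketch mirroring the list test later in this diff (it assumes MemoryStore's `get`/`put` accept and return the raw value dict, as the adapters themselves rely on):

```python
import asyncio
from dataclasses import dataclass

from key_value.aio.adapters.dataclass import DataclassAdapter
from key_value.aio.stores.memory.store import MemoryStore


@dataclass
class Product:
    name: str
    price: float
    quantity: int


async def main() -> None:
    store = MemoryStore()
    adapter = DataclassAdapter[list[Product]](key_value=store, dataclass_type=list[Product], default_collection="shop")

    await adapter.put(key="catalog", value=[Product(name="Widget", price=29.99, quantity=10)])

    # The store never holds a bare list: list models are wrapped under an "items" key.
    raw = await store.get(key="catalog", collection="shop")
    assert raw == {"items": [{"name": "Widget", "price": 29.99, "quantity": 10}]}

    # A dict without the "items" wrapper is treated as invalid and, by default, reads back as None.
    await store.put(key="broken", value={"wrong": []}, collection="shop")
    assert await adapter.get(key="broken") is None


asyncio.run(main())
```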
+ + @overload + async def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: None = None) -> list[T | None]: ... + + async def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: T | None = None) -> list[T] | list[T | None]: + """Batch get and validate models by keys, preserving order. + + Args: + keys: The list of keys to retrieve. + collection: The collection to use. If not provided, uses the default collection. + default: The default value to return for keys that don't exist or fail validation. + + Returns: + A list of parsed model instances, with default values for missing keys or validation failures. + + Raises: + DeserializationError: If the stored data cannot be validated as the model and the adapter is configured to + raise on validation error. + + Note: + When raise_on_validation_error=False and validation fails for any key, that position in the returned list + will contain the default value (which may be None). The method returns a complete list matching the order + and length of the input keys, with defaults substituted for missing or invalid entries. + """ + collection = collection or self._default_collection + + values: list[dict[str, Any] | None] = await self._key_value.get_many(keys=keys, collection=collection) + + result: list[T | None] = [] + for value in values: + if value is None: + result.append(default) + else: + validated = self._validate_model(value=value) + result.append(validated if validated is not None else default) + return result + + async def put(self, key: str, value: T, *, collection: str | None = None, ttl: SupportsFloat | None = None) -> None: + """Serialize and store a model. + + Propagates SerializationError if the model cannot be serialized. + """ + collection = collection or self._default_collection + + value_dict: dict[str, Any] = self._serialize_model(value=value) + + await self._key_value.put(key=key, value=value_dict, collection=collection, ttl=ttl) + + async def put_many( + self, keys: Sequence[str], values: Sequence[T], *, collection: str | None = None, ttl: SupportsFloat | None = None + ) -> None: + """Serialize and store multiple models, preserving order alignment with keys.""" + collection = collection or self._default_collection + + value_dicts: list[dict[str, Any]] = [self._serialize_model(value=value) for value in values] + + await self._key_value.put_many(keys=keys, values=value_dicts, collection=collection, ttl=ttl) + + async def delete(self, key: str, *, collection: str | None = None) -> bool: + """Delete a model by key. Returns True if a value was deleted, else False.""" + collection = collection or self._default_collection + + return await self._key_value.delete(key=key, collection=collection) + + async def delete_many(self, keys: Sequence[str], *, collection: str | None = None) -> int: + """Delete multiple models by key. Returns the count of deleted entries.""" + collection = collection or self._default_collection + + return await self._key_value.delete_many(keys=keys, collection=collection) + + async def ttl(self, key: str, *, collection: str | None = None) -> tuple[T | None, float | None]: + """Get a model and its TTL seconds if present. + + Args: + key: The key to retrieve. + collection: The collection to use. If not provided, uses the default collection. + + Returns: + A tuple of (model, ttl_seconds). Returns (None, None) if the key is missing or validation fails. 
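A sketch of the default and validation-error semantics documented above, using the dataclass adapter and the schema-drift scenario from the tests (`UserV2` adds a required field the stored payload lacks):

```python
import asyncio
from dataclasses import dataclass

from key_value.shared.errors import DeserializationError

from key_value.aio.adapters.dataclass import DataclassAdapter
from key_value.aio.stores.memory.store import MemoryStore


@dataclass
class User:
    name: str
    age: int


@dataclass
class UserV2:
    name: str
    age: int
    is_admin: bool  # new required field that the stored payload does not have


async def main() -> None:
    store = MemoryStore()
    v1 = DataclassAdapter[User](key_value=store, dataclass_type=User, default_collection="users")
    await v1.put(key="jane", value=User(name="Jane", age=25))

    # Missing keys fall back to the supplied default (or None when no default is given).
    fallback = User(name="anonymous", age=0)
    assert await v1.get(key="missing", default=fallback) == fallback
    assert await v1.get_many(keys=["jane", "missing"], default=fallback) == [User(name="Jane", age=25), fallback]

    # Lenient mode (the default): incompatible stored data reads back as None / the default.
    v2 = DataclassAdapter[UserV2](key_value=store, dataclass_type=UserV2, default_collection="users")
    assert await v2.get(key="jane") is None

    # Strict mode: the same read raises DeserializationError instead of hiding the mismatch.
    strict = DataclassAdapter[UserV2](key_value=store, dataclass_type=UserV2, default_collection="users", raise_on_validation_error=True)
    try:
        await strict.get(key="jane")
    except DeserializationError:
        print("stored payload no longer matches the model")


asyncio.run(main())
```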
+ + Note: + When validation fails and raise_on_validation_error=False, returns (None, None) even if TTL data exists. + When validation fails and raise_on_validation_error=True, raises DeserializationError. + """ + collection = collection or self._default_collection + + entry: dict[str, Any] | None + ttl_info: float | None + + entry, ttl_info = await self._key_value.ttl(key=key, collection=collection) + + if entry is None: + return (None, None) + + validated_model = self._validate_model(value=entry) + if validated_model is not None: + return (validated_model, ttl_info) + + return (None, None) + + async def ttl_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[tuple[T | None, float | None]]: + """Batch get models with TTLs. Each element is (model|None, ttl_seconds|None).""" + collection = collection or self._default_collection + + entries: list[tuple[dict[str, Any] | None, float | None]] = await self._key_value.ttl_many(keys=keys, collection=collection) + + return [(self._validate_model(value=entry) if entry is not None else None, ttl_info) for entry, ttl_info in entries] diff --git a/key-value/key-value-aio/src/key_value/aio/protocols/key_value.py b/key-value/key-value-aio/src/key_value/aio/protocols/key_value.py index e3cb9624..5e8a1426 100644 --- a/key-value/key-value-aio/src/key_value/aio/protocols/key_value.py +++ b/key-value/key-value-aio/src/key_value/aio/protocols/key_value.py @@ -6,8 +6,9 @@ class AsyncKeyValueProtocol(Protocol): """A subset of KV operations: get/put/delete and TTL variants, including bulk calls. - This protocol defines the minimal contract for key-value store implementations. All methods are - async and may raise exceptions on connection failures, validation errors, or other operational issues. + This protocol defines the minimal contract for key-value store implementations. All methods may + raise exceptions on connection failures, validation errors, or other operational issues. + Implementations should handle backend-specific errors appropriately. 
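For context on the protocol hunk above, the sketch below shows a toy store that satisfies the same minimal contract; its signatures are inferred solely from how `BasePydanticAdapter` calls the store in this diff, so the real `AsyncKeyValueProtocol` may differ in details.

```python
from collections.abc import Sequence
from typing import Any, SupportsFloat


class DictBackedStore:
    """Toy in-memory store keyed by (collection, key); TTLs are accepted but not enforced here."""

    def __init__(self) -> None:
        self._data: dict[tuple[str | None, str], dict[str, Any]] = {}

    async def get(self, key: str, collection: str | None = None) -> dict[str, Any] | None:
        return self._data.get((collection, key))

    async def get_many(self, keys: Sequence[str], collection: str | None = None) -> list[dict[str, Any] | None]:
        return [self._data.get((collection, key)) for key in keys]

    async def put(self, key: str, value: dict[str, Any], collection: str | None = None, ttl: SupportsFloat | None = None) -> None:
        self._data[(collection, key)] = value

    async def put_many(
        self, keys: Sequence[str], values: Sequence[dict[str, Any]], collection: str | None = None, ttl: SupportsFloat | None = None
    ) -> None:
        for key, value in zip(keys, values):
            self._data[(collection, key)] = value

    async def delete(self, key: str, collection: str | None = None) -> bool:
        return self._data.pop((collection, key), None) is not None

    async def delete_many(self, keys: Sequence[str], collection: str | None = None) -> int:
        return sum([await self.delete(key=key, collection=collection) for key in keys])

    async def ttl(self, key: str, collection: str | None = None) -> tuple[dict[str, Any] | None, float | None]:
        return (self._data.get((collection, key)), None)

    async def ttl_many(self, keys: Sequence[str], collection: str | None = None) -> list[tuple[dict[str, Any] | None, float | None]]:
        return [(self._data.get((collection, key)), None) for key in keys]


# A store like this can be handed to the adapters in place of a built-in backend, e.g.:
#     DataclassAdapter[User](key_value=DictBackedStore(), dataclass_type=User)
```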
""" diff --git a/key-value/key-value-aio/tests/adapters/test_dataclass.py b/key-value/key-value-aio/tests/adapters/test_dataclass.py new file mode 100644 index 00000000..352edc61 --- /dev/null +++ b/key-value/key-value-aio/tests/adapters/test_dataclass.py @@ -0,0 +1,253 @@ +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Any + +import pytest +from inline_snapshot import snapshot +from key_value.shared.errors import DeserializationError + +from key_value.aio.adapters.dataclass import DataclassAdapter +from key_value.aio.stores.memory.store import MemoryStore + + +@dataclass +class User: + name: str + age: int + email: str + + +@dataclass +class UpdatedUser: + name: str + age: int + email: str + is_admin: bool + + +@dataclass +class Product: + name: str + price: float + quantity: int + + +@dataclass +class Address: + street: str + city: str + zip_code: str + + +@dataclass +class UserWithAddress: + name: str + age: int + address: Address + + +@dataclass +class Order: + created_at: datetime + updated_at: datetime + user: User + product: Product + paid: bool = False + + +FIXED_CREATED_AT: datetime = datetime(year=2021, month=1, day=1, hour=12, minute=0, second=0, tzinfo=timezone.utc) +FIXED_UPDATED_AT: datetime = datetime(year=2021, month=1, day=1, hour=15, minute=0, second=0, tzinfo=timezone.utc) + +SAMPLE_USER: User = User(name="John Doe", email="john.doe@example.com", age=30) +SAMPLE_USER_2: User = User(name="Jane Doe", email="jane.doe@example.com", age=25) +SAMPLE_PRODUCT: Product = Product(name="Widget", price=29.99, quantity=10) +SAMPLE_ADDRESS: Address = Address(street="123 Main St", city="Springfield", zip_code="12345") +SAMPLE_USER_WITH_ADDRESS: UserWithAddress = UserWithAddress(name="John Doe", age=30, address=SAMPLE_ADDRESS) +SAMPLE_ORDER: Order = Order(created_at=FIXED_CREATED_AT, updated_at=FIXED_UPDATED_AT, user=SAMPLE_USER, product=SAMPLE_PRODUCT, paid=False) + +TEST_COLLECTION: str = "test_collection" +TEST_KEY: str = "test_key" +TEST_KEY_2: str = "test_key_2" + + +class TestDataclassAdapter: + @pytest.fixture + async def store(self) -> MemoryStore: + return MemoryStore() + + @pytest.fixture + async def user_adapter(self, store: MemoryStore) -> DataclassAdapter[User]: + return DataclassAdapter[User](key_value=store, dataclass_type=User) + + @pytest.fixture + async def updated_user_adapter(self, store: MemoryStore) -> DataclassAdapter[UpdatedUser]: + return DataclassAdapter[UpdatedUser](key_value=store, dataclass_type=UpdatedUser) + + @pytest.fixture + async def product_adapter(self, store: MemoryStore) -> DataclassAdapter[Product]: + return DataclassAdapter[Product](key_value=store, dataclass_type=Product) + + @pytest.fixture + async def product_list_adapter(self, store: MemoryStore) -> DataclassAdapter[list[Product]]: + return DataclassAdapter[list[Product]](key_value=store, dataclass_type=list[Product]) + + @pytest.fixture + async def user_with_address_adapter(self, store: MemoryStore) -> DataclassAdapter[UserWithAddress]: + return DataclassAdapter[UserWithAddress](key_value=store, dataclass_type=UserWithAddress) + + @pytest.fixture + async def order_adapter(self, store: MemoryStore) -> DataclassAdapter[Order]: + return DataclassAdapter[Order](key_value=store, dataclass_type=Order) + + async def test_simple_adapter(self, user_adapter: DataclassAdapter[User]): + """Test basic put/get/delete operations with a simple dataclass.""" + await user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER) + cached_user: 
User | None = await user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) + assert cached_user == SAMPLE_USER + + assert await user_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY) + + assert await user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None + + async def test_simple_adapter_with_default(self, user_adapter: DataclassAdapter[User]): + """Test default value handling.""" + assert await user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY, default=SAMPLE_USER) == SAMPLE_USER + + await user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER_2) + assert await user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY, default=SAMPLE_USER) == SAMPLE_USER_2 + + assert await user_adapter.get_many(collection=TEST_COLLECTION, keys=[TEST_KEY, TEST_KEY_2], default=SAMPLE_USER) == snapshot( + [SAMPLE_USER_2, SAMPLE_USER] + ) + + async def test_simple_adapter_with_validation_error_ignore( + self, user_adapter: DataclassAdapter[User], updated_user_adapter: DataclassAdapter[UpdatedUser] + ): + """Test that validation errors return None when raise_on_validation_error is False.""" + await user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER) + + # UpdatedUser requires is_admin field which doesn't exist in stored User + updated_user = await updated_user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) + assert updated_user is None + + async def test_simple_adapter_with_validation_error_raise( + self, user_adapter: DataclassAdapter[User], updated_user_adapter: DataclassAdapter[UpdatedUser] + ): + """Test that validation errors raise DeserializationError when raise_on_validation_error is True.""" + await user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER) + updated_user_adapter._raise_on_validation_error = True # pyright: ignore[reportPrivateUsage] + with pytest.raises(DeserializationError): + await updated_user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) + + async def test_nested_dataclass(self, user_with_address_adapter: DataclassAdapter[UserWithAddress]): + """Test that nested dataclasses are properly serialized and deserialized.""" + await user_with_address_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER_WITH_ADDRESS) + cached_user: UserWithAddress | None = await user_with_address_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) + assert cached_user == SAMPLE_USER_WITH_ADDRESS + assert cached_user is not None + assert cached_user.address.street == "123 Main St" + + async def test_complex_adapter(self, order_adapter: DataclassAdapter[Order]): + """Test complex dataclass with nested objects and TTL.""" + await order_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_ORDER, ttl=10) + cached_order: Order | None = await order_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) + assert cached_order == SAMPLE_ORDER + + assert await order_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY) + assert await order_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None + + async def test_complex_adapter_with_list(self, product_list_adapter: DataclassAdapter[list[Product]], store: MemoryStore): + """Test list dataclass serialization with proper wrapping.""" + await product_list_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=[SAMPLE_PRODUCT, SAMPLE_PRODUCT], ttl=10) + cached_products: list[Product] | None = await product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) + assert cached_products == [SAMPLE_PRODUCT, 
SAMPLE_PRODUCT] + + # We need to ensure our memory store doesn't hold an entry with an array + raw_collection = store._cache.get(TEST_COLLECTION) # pyright: ignore[reportPrivateUsage] + assert raw_collection is not None + + raw_entry = raw_collection.get(TEST_KEY) + assert raw_entry is not None + assert isinstance(raw_entry.value, dict) + assert raw_entry.value == snapshot( + {"items": [{"name": "Widget", "price": 29.99, "quantity": 10}, {"name": "Widget", "price": 29.99, "quantity": 10}]} + ) + + assert await product_list_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY) + assert await product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None + + async def test_batch_operations(self, user_adapter: DataclassAdapter[User]): + """Test batch put/get/delete operations.""" + keys = [TEST_KEY, TEST_KEY_2] + users = [SAMPLE_USER, SAMPLE_USER_2] + + # Test put_many + await user_adapter.put_many(collection=TEST_COLLECTION, keys=keys, values=users) + + # Test get_many + cached_users = await user_adapter.get_many(collection=TEST_COLLECTION, keys=keys) + assert cached_users == users + + # Test delete_many + deleted_count = await user_adapter.delete_many(collection=TEST_COLLECTION, keys=keys) + assert deleted_count == 2 + + # Verify deletion + cached_users_after_delete = await user_adapter.get_many(collection=TEST_COLLECTION, keys=keys) + assert cached_users_after_delete == [None, None] + + async def test_ttl_operations(self, user_adapter: DataclassAdapter[User]): + """Test TTL-related operations.""" + # Test single TTL + await user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER, ttl=10) + user, ttl = await user_adapter.ttl(collection=TEST_COLLECTION, key=TEST_KEY) + assert user == SAMPLE_USER + assert ttl is not None + assert ttl > 0 + + # Test ttl_many + await user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY_2, value=SAMPLE_USER_2, ttl=20) + ttl_results = await user_adapter.ttl_many(collection=TEST_COLLECTION, keys=[TEST_KEY, TEST_KEY_2]) + assert len(ttl_results) == 2 + assert ttl_results[0][0] == SAMPLE_USER + assert ttl_results[1][0] == SAMPLE_USER_2 + + async def test_dataclass_validation_on_init(self, store: MemoryStore): + """Test that non-dataclass types are rejected.""" + with pytest.raises(TypeError, match="is not a dataclass"): + DataclassAdapter[str](key_value=store, dataclass_type=str) # type: ignore[type-var] + + async def test_default_collection(self, store: MemoryStore): + """Test that default collection is used when not specified.""" + adapter = DataclassAdapter[User](key_value=store, dataclass_type=User, default_collection=TEST_COLLECTION) + + await adapter.put(key=TEST_KEY, value=SAMPLE_USER) + cached_user = await adapter.get(key=TEST_KEY) + assert cached_user == SAMPLE_USER + + assert await adapter.delete(key=TEST_KEY) + + async def test_ttl_with_empty_list(self, product_list_adapter: DataclassAdapter[list[Product]]): + """Test that TTL with empty list returns correctly (not None).""" + await product_list_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=[], ttl=10) + value, ttl = await product_list_adapter.ttl(collection=TEST_COLLECTION, key=TEST_KEY) + assert value == [] + assert ttl is not None + assert ttl > 0 + + async def test_list_payload_missing_items_returns_none(self, product_list_adapter: DataclassAdapter[list[Product]], store: MemoryStore): + """Test that list payload without 'items' wrapper returns None when raise_on_validation_error is False.""" + # Manually insert malformed payload without the 
'items' wrapper + # The payload is a dict but without the expected 'items' key for list models + malformed_payload: dict[str, Any] = {"wrong": []} + await store.put(collection=TEST_COLLECTION, key=TEST_KEY, value=malformed_payload) + assert await product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None + + async def test_list_payload_missing_items_raises(self, product_list_adapter: DataclassAdapter[list[Product]], store: MemoryStore): + """Test that list payload without 'items' wrapper raises DeserializationError when configured.""" + product_list_adapter._raise_on_validation_error = True # pyright: ignore[reportPrivateUsage] + # Manually insert malformed payload without the 'items' wrapper + malformed_payload: dict[str, Any] = {"wrong": []} + await store.put(collection=TEST_COLLECTION, key=TEST_KEY, value=malformed_payload) + with pytest.raises(DeserializationError, match="missing 'items'"): + await product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) diff --git a/key-value/key-value-aio/tests/adapters/test_pydantic.py b/key-value/key-value-aio/tests/adapters/test_pydantic.py index 6d7cdaef..617786c1 100644 --- a/key-value/key-value-aio/tests/adapters/test_pydantic.py +++ b/key-value/key-value-aio/tests/adapters/test_pydantic.py @@ -91,6 +91,14 @@ async def test_simple_adapter_with_default(self, user_adapter: PydanticAdapter[U [SAMPLE_USER_2, SAMPLE_USER] ) + async def test_simple_adapter_with_list(self, product_list_adapter: PydanticAdapter[list[Product]]): + await product_list_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=[SAMPLE_PRODUCT, SAMPLE_PRODUCT]) + cached_products: list[Product] | None = await product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) + assert cached_products == [SAMPLE_PRODUCT, SAMPLE_PRODUCT] + + assert await product_list_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY) + assert await product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None + async def test_simple_adapter_with_validation_error_ignore( self, user_adapter: PydanticAdapter[User], updated_user_adapter: PydanticAdapter[UpdatedUser] ): diff --git a/key-value/key-value-sync/src/key_value/sync/adapters/base/__init__.py b/key-value/key-value-sync/src/key_value/sync/adapters/base/__init__.py new file mode 100644 index 00000000..2fc794fa --- /dev/null +++ b/key-value/key-value-sync/src/key_value/sync/adapters/base/__init__.py @@ -0,0 +1,6 @@ +# WARNING: this file is auto-generated by 'build_sync_library.py' +# from the original file '__init__.py' +# DO NOT CHANGE! Change the original file instead. +from key_value.sync.code_gen.adapters.pydantic.base import BasePydanticAdapter + +__all__ = ["BasePydanticAdapter"] diff --git a/key-value/key-value-sync/src/key_value/sync/adapters/dataclass/__init__.py b/key-value/key-value-sync/src/key_value/sync/adapters/dataclass/__init__.py new file mode 100644 index 00000000..7583b7de --- /dev/null +++ b/key-value/key-value-sync/src/key_value/sync/adapters/dataclass/__init__.py @@ -0,0 +1,6 @@ +# WARNING: this file is auto-generated by 'build_sync_library.py' +# from the original file '__init__.py' +# DO NOT CHANGE! Change the original file instead. 
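Because the sync package mirrors the async one through code generation, usage is identical minus the `await`; a sketch follows (the sync MemoryStore import path is assumed by analogy with the async package and is not shown in this diff):

```python
from dataclasses import dataclass

from key_value.sync.adapters.dataclass import DataclassAdapter  # thin re-export of the code_gen module above
from key_value.sync.stores.memory.store import MemoryStore  # import path assumed by analogy with key_value.aio


@dataclass
class User:
    name: str
    age: int


store = MemoryStore()
adapter = DataclassAdapter[User](key_value=store, dataclass_type=User, default_collection="users")

adapter.put(key="jane", value=User(name="Jane", age=25))
assert adapter.get(key="jane") == User(name="Jane", age=25)
assert adapter.delete(key="jane")
```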
+from key_value.sync.code_gen.adapters.dataclass.adapter import DataclassAdapter + +__all__ = ["DataclassAdapter"] diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/adapters/base/__init__.py b/key-value/key-value-sync/src/key_value/sync/code_gen/adapters/base/__init__.py new file mode 100644 index 00000000..2fc794fa --- /dev/null +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/adapters/base/__init__.py @@ -0,0 +1,6 @@ +# WARNING: this file is auto-generated by 'build_sync_library.py' +# from the original file '__init__.py' +# DO NOT CHANGE! Change the original file instead. +from key_value.sync.code_gen.adapters.pydantic.base import BasePydanticAdapter + +__all__ = ["BasePydanticAdapter"] diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/adapters/dataclass/__init__.py b/key-value/key-value-sync/src/key_value/sync/code_gen/adapters/dataclass/__init__.py new file mode 100644 index 00000000..7583b7de --- /dev/null +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/adapters/dataclass/__init__.py @@ -0,0 +1,6 @@ +# WARNING: this file is auto-generated by 'build_sync_library.py' +# from the original file '__init__.py' +# DO NOT CHANGE! Change the original file instead. +from key_value.sync.code_gen.adapters.dataclass.adapter import DataclassAdapter + +__all__ = ["DataclassAdapter"] diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/adapters/dataclass/adapter.py b/key-value/key-value-sync/src/key_value/sync/code_gen/adapters/dataclass/adapter.py new file mode 100644 index 00000000..991751a7 --- /dev/null +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/adapters/dataclass/adapter.py @@ -0,0 +1,71 @@ +# WARNING: this file is auto-generated by 'build_sync_library.py' +# from the original file 'adapter.py' +# DO NOT CHANGE! Change the original file instead. +from collections.abc import Sequence +from dataclasses import is_dataclass +from typing import Any, TypeVar, get_args, get_origin + +from key_value.shared.type_checking.bear_spray import bear_spray +from pydantic.type_adapter import TypeAdapter + +from key_value.sync.code_gen.adapters.pydantic.base import BasePydanticAdapter +from key_value.sync.code_gen.protocols.key_value import KeyValue + +T = TypeVar("T") + + +class DataclassAdapter(BasePydanticAdapter[T]): + """Adapter around a KVStore-compliant Store that allows type-safe persistence of dataclasses. + + This adapter works with both standard library dataclasses and Pydantic dataclasses, + leveraging Pydantic's TypeAdapter for robust validation and serialization. + """ + + _inner_type: type[Any] + + # Beartype cannot handle the parameterized type annotation (type[T]) used here for this generic dataclass adapter. + # Using @bear_spray to bypass beartype's runtime checks for this specific method. + + @bear_spray + def __init__( + self, key_value: KeyValue, dataclass_type: type[T], default_collection: str | None = None, raise_on_validation_error: bool = False + ) -> None: + """Create a new DataclassAdapter. + + Args: + key_value: The KeyValue to use. + dataclass_type: The dataclass type to use. Can be a single dataclass or list[dataclass]. + default_collection: The default collection to use. + raise_on_validation_error: Whether to raise a DeserializationError if validation fails during reads. Otherwise, + calls will return None if validation fails. + + Raises: + TypeError: If dataclass_type is not a dataclass type. 
+ """ + self._key_value = key_value + + origin = get_origin(dataclass_type) + self._is_list_model = origin is not None and isinstance(origin, type) and issubclass(origin, Sequence) + + # Extract the inner type for list models + if self._is_list_model: + args = get_args(dataclass_type) + if not args: + msg = f"List type {dataclass_type} must have a type argument" + raise TypeError(msg) + self._inner_type = args[0] + else: + self._inner_type = dataclass_type + + # Validate that the inner type is a dataclass + if not is_dataclass(self._inner_type): + msg = f"{self._inner_type} is not a dataclass" + raise TypeError(msg) + + self._type_adapter = TypeAdapter[T](dataclass_type) + self._default_collection = default_collection + self._raise_on_validation_error = raise_on_validation_error + + def _get_model_type_name(self) -> str: + """Return the model type name for error messages.""" + return "dataclass" diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/adapters/pydantic/adapter.py b/key-value/key-value-sync/src/key_value/sync/code_gen/adapters/pydantic/adapter.py index 8030ecd0..a1973de9 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/adapters/pydantic/adapter.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/adapters/pydantic/adapter.py @@ -2,30 +2,23 @@ # from the original file 'adapter.py' # DO NOT CHANGE! Change the original file instead. from collections.abc import Sequence -from typing import Any, Generic, SupportsFloat, TypeVar, get_origin, overload +from typing import TypeVar, get_origin -from key_value.shared.errors import DeserializationError, SerializationError from key_value.shared.type_checking.bear_spray import bear_spray -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel from pydantic.type_adapter import TypeAdapter -from pydantic_core import PydanticSerializationError +from key_value.sync.code_gen.adapters.pydantic.base import BasePydanticAdapter from key_value.sync.code_gen.protocols.key_value import KeyValue T = TypeVar("T", bound=BaseModel | Sequence[BaseModel]) -class PydanticAdapter(Generic[T]): +class PydanticAdapter(BasePydanticAdapter[T]): """Adapter around a KVStore-compliant Store that allows type-safe persistence of Pydantic models.""" - _key_value: KeyValue - _is_list_model: bool - _type_adapter: TypeAdapter[T] - _default_collection: str | None - _raise_on_validation_error: bool - - # Beartype doesn't like our `type[T] includes a bound on Sequence[...] as the subscript is not checkable at runtime - # For just the next 20 or so lines we are no longer bear bros but have no fear, we will be back soon! + # Beartype cannot handle the parameterized type annotation (type[T]) used here for this generic adapter. + # Using @bear_spray to bypass beartype's runtime checks for this specific method. @bear_spray def __init__( @@ -35,212 +28,28 @@ def __init__( Args: key_value: The KVStore to use. - pydantic_model: The Pydantic model to use. + pydantic_model: The Pydantic model to use. Can be a single Pydantic model or list[Pydantic model]. default_collection: The default collection to use. - raise_on_validation_error: Whether to raise a ValidationError if the model is invalid. - """ + raise_on_validation_error: Whether to raise a DeserializationError if validation fails during reads. Otherwise, + calls will return None if validation fails. + Raises: + TypeError: If pydantic_model is a sequence type other than list (e.g., tuple is not supported). 
+ """ self._key_value = key_value origin = get_origin(pydantic_model) - self._is_list_model = origin is not None and isinstance(origin, type) and issubclass(origin, Sequence) + self._is_list_model = origin is list + + # Validate that if it's a generic type, it must be a list (not tuple, etc.) + if origin is not None and origin is not list: + msg = f"Only list[BaseModel] is supported for sequence types, got {pydantic_model}" + raise TypeError(msg) self._type_adapter = TypeAdapter[T](pydantic_model) self._default_collection = default_collection self._raise_on_validation_error = raise_on_validation_error - def _validate_model(self, value: dict[str, Any]) -> T | None: - """Validate and deserialize a dict into the configured Pydantic model. - - This method handles both single models and list models. For list models, it expects the value - to contain an "items" key with the list data, following the convention used by `_serialize_model`. - If validation fails and `raise_on_validation_error` is False, returns None instead of raising. - - Args: - value: The dict to validate and convert to a Pydantic model. - - Returns: - The validated model instance, or None if validation fails and errors are suppressed. - - Raises: - DeserializationError: If validation fails and `raise_on_validation_error` is True. - """ - try: - if self._is_list_model: - return self._type_adapter.validate_python(value.get("items", [])) - - return self._type_adapter.validate_python(value) - except ValidationError as e: - if self._raise_on_validation_error: - msg = f"Invalid Pydantic model: {value}" - raise DeserializationError(msg) from e - return None - - def _serialize_model(self, value: T) -> dict[str, Any]: - """Serialize a Pydantic model to a dict for storage. - - This method handles both single models and list models. For list models, it wraps the serialized - list in a dict with an "items" key (e.g., {"items": [...]}) to ensure consistent dict-based storage - format across all value types. This wrapping convention is expected by `_validate_model` during - deserialization. - - Args: - value: The Pydantic model instance to serialize. - - Returns: - A dict representation of the model suitable for storage. - - Raises: - SerializationError: If the model cannot be serialized. - """ - try: - if self._is_list_model: - return {"items": self._type_adapter.dump_python(value, mode="json")} - - return self._type_adapter.dump_python(value, mode="json") # pyright: ignore[reportAny] - except PydanticSerializationError as e: - msg = f"Invalid Pydantic model: {e}" - raise SerializationError(msg) from e - - @overload - def get(self, key: str, *, collection: str | None = None, default: T) -> T: ... - - @overload - def get(self, key: str, *, collection: str | None = None, default: None = None) -> T | None: ... - - def get(self, key: str, *, collection: str | None = None, default: T | None = None) -> T | None: - """Get and validate a model by key. - - Args: - key: The key to retrieve. - collection: The collection to use. If not provided, uses the default collection. - default: The default value to return if the key doesn't exist or validation fails. - - Returns: - The parsed model instance if found and valid, or the default value if key doesn't exist or validation fails. - - Raises: - DeserializationError if the stored data cannot be validated as the model and the PydanticAdapter is configured to - raise on validation error. - - Note: - When raise_on_validation_error=False and validation fails, returns the default value (which may be None). 
- When raise_on_validation_error=True and validation fails, raises DeserializationError. - """ - collection = collection or self._default_collection - - if value := self._key_value.get(key=key, collection=collection): - validated = self._validate_model(value=value) - if validated is not None: - return validated - - return default - - @overload - def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: T) -> list[T]: ... - - @overload - def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: None = None) -> list[T | None]: ... - - def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: T | None = None) -> list[T] | list[T | None]: - """Batch get and validate models by keys, preserving order. - - Args: - keys: The list of keys to retrieve. - collection: The collection to use. If not provided, uses the default collection. - default: The default value to return for keys that don't exist or fail validation. - - Returns: - A list of parsed model instances, with default values for missing keys or validation failures. - - Raises: - DeserializationError if the stored data cannot be validated as the model and the PydanticAdapter is configured to - raise on validation error. - - Note: - When raise_on_validation_error=False and validation fails for any key, that position in the returned list - will contain the default value (which may be None). The method returns a complete list matching the order - and length of the input keys, with defaults substituted for missing or invalid entries. - """ - collection = collection or self._default_collection - - values: list[dict[str, Any] | None] = self._key_value.get_many(keys=keys, collection=collection) - - result: list[T | None] = [] - for value in values: - if value is None: - result.append(default) - else: - validated = self._validate_model(value=value) - result.append(validated if validated is not None else default) - return result - - def put(self, key: str, value: T, *, collection: str | None = None, ttl: SupportsFloat | None = None) -> None: - """Serialize and store a model. - - Propagates SerializationError if the model cannot be serialized. - """ - collection = collection or self._default_collection - - value_dict: dict[str, Any] = self._serialize_model(value=value) - - self._key_value.put(key=key, value=value_dict, collection=collection, ttl=ttl) - - def put_many( - self, keys: Sequence[str], values: Sequence[T], *, collection: str | None = None, ttl: SupportsFloat | None = None - ) -> None: - """Serialize and store multiple models, preserving order alignment with keys.""" - collection = collection or self._default_collection - - value_dicts: list[dict[str, Any]] = [self._serialize_model(value=value) for value in values] - - self._key_value.put_many(keys=keys, values=value_dicts, collection=collection, ttl=ttl) - - def delete(self, key: str, *, collection: str | None = None) -> bool: - """Delete a model by key. Returns True if a value was deleted, else False.""" - collection = collection or self._default_collection - - return self._key_value.delete(key=key, collection=collection) - - def delete_many(self, keys: Sequence[str], *, collection: str | None = None) -> int: - """Delete multiple models by key. 
Returns the count of deleted entries.""" - collection = collection or self._default_collection - - return self._key_value.delete_many(keys=keys, collection=collection) - - def ttl(self, key: str, *, collection: str | None = None) -> tuple[T | None, float | None]: - """Get a model and its TTL seconds if present. - - Args: - key: The key to retrieve. - collection: The collection to use. If not provided, uses the default collection. - - Returns: - A tuple of (model, ttl_seconds). Returns (None, None) if the key is missing or validation fails. - - Note: - When validation fails and raise_on_validation_error=False, returns (None, None) even if TTL data exists. - When validation fails and raise_on_validation_error=True, raises DeserializationError. - """ - collection = collection or self._default_collection - - entry: dict[str, Any] | None - ttl_info: float | None - - (entry, ttl_info) = self._key_value.ttl(key=key, collection=collection) - - if entry is None: - return (None, None) - - if validated_model := self._validate_model(value=entry): - return (validated_model, ttl_info) - - return (None, None) - - def ttl_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[tuple[T | None, float | None]]: - """Batch get models with TTLs. Each element is (model|None, ttl_seconds|None).""" - collection = collection or self._default_collection - - entries: list[tuple[dict[str, Any] | None, float | None]] = self._key_value.ttl_many(keys=keys, collection=collection) - - return [(self._validate_model(value=entry) if entry else None, ttl_info) for (entry, ttl_info) in entries] + def _get_model_type_name(self) -> str: + """Return the model type name for error messages.""" + return "Pydantic model" diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/adapters/pydantic/base.py b/key-value/key-value-sync/src/key_value/sync/code_gen/adapters/pydantic/base.py new file mode 100644 index 00000000..f2801dba --- /dev/null +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/adapters/pydantic/base.py @@ -0,0 +1,243 @@ +# WARNING: this file is auto-generated by 'build_sync_library.py' +# from the original file 'base.py' +# DO NOT CHANGE! Change the original file instead. +from abc import ABC, abstractmethod +from collections.abc import Sequence +from typing import Any, Generic, SupportsFloat, TypeVar, overload + +from key_value.shared.errors import DeserializationError, SerializationError +from pydantic import ValidationError +from pydantic.type_adapter import TypeAdapter +from pydantic_core import PydanticSerializationError + +from key_value.sync.code_gen.protocols.key_value import KeyValue + +T = TypeVar("T") + + +class BasePydanticAdapter(Generic[T], ABC): + """Base adapter using Pydantic TypeAdapter for validation and serialization. + + This abstract base class provides shared functionality for adapters that use + Pydantic's TypeAdapter for validation and serialization. Concrete subclasses + must implement _get_model_type_name() to provide appropriate error messages. + """ + + _key_value: KeyValue + _is_list_model: bool + _type_adapter: TypeAdapter[T] + _default_collection: str | None + _raise_on_validation_error: bool + + @abstractmethod + def _get_model_type_name(self) -> str: + """Return the model type name for error messages. + + Returns: + A string describing the model type (e.g., "Pydantic model", "dataclass"). + """ + ... + + def _validate_model(self, value: dict[str, Any]) -> T | None: + """Validate and deserialize a dict into the configured model type. 
+ + This method handles both single models and list models. For list models, it expects the value + to contain an "items" key with the list data, following the convention used by `_serialize_model`. + If validation fails and `raise_on_validation_error` is False, returns None instead of raising. + + Args: + value: The dict to validate and convert to a model. + + Returns: + The validated model instance, or None if validation fails and errors are suppressed. + + Raises: + DeserializationError: If validation fails and `raise_on_validation_error` is True. + """ + try: + if self._is_list_model: + if "items" not in value: + if self._raise_on_validation_error: + msg = f"Invalid {self._get_model_type_name()} payload: missing 'items' wrapper" + raise DeserializationError(msg) + return None + return self._type_adapter.validate_python(value["items"]) + + return self._type_adapter.validate_python(value) + except ValidationError as e: + if self._raise_on_validation_error: + details = e.errors(include_input=False) + msg = f"Invalid {self._get_model_type_name()}: {details}" + raise DeserializationError(msg) from e + return None + + def _serialize_model(self, value: T) -> dict[str, Any]: + """Serialize a model to a dict for storage. + + This method handles both single models and list models. For list models, it wraps the serialized + list in a dict with an "items" key (e.g., {"items": [...]}) to ensure consistent dict-based storage + format across all value types. This wrapping convention is expected by `_validate_model` during + deserialization. + + Args: + value: The model instance to serialize. + + Returns: + A dict representation of the model suitable for storage. + + Raises: + SerializationError: If the model cannot be serialized. + """ + try: + if self._is_list_model: + return {"items": self._type_adapter.dump_python(value, mode="json")} + + return self._type_adapter.dump_python(value, mode="json") # pyright: ignore[reportAny] + except PydanticSerializationError as e: + msg = f"Invalid {self._get_model_type_name()}: {e}" + raise SerializationError(msg) from e + + @overload + def get(self, key: str, *, collection: str | None = None, default: T) -> T: ... + + @overload + def get(self, key: str, *, collection: str | None = None, default: None = None) -> T | None: ... + + def get(self, key: str, *, collection: str | None = None, default: T | None = None) -> T | None: + """Get and validate a model by key. + + Args: + key: The key to retrieve. + collection: The collection to use. If not provided, uses the default collection. + default: The default value to return if the key doesn't exist or validation fails. + + Returns: + The parsed model instance if found and valid, or the default value if key doesn't exist or validation fails. + + Raises: + DeserializationError: If the stored data cannot be validated as the model and the adapter is configured to + raise on validation error. + + Note: + When raise_on_validation_error=False and validation fails, returns the default value (which may be None). + When raise_on_validation_error=True and validation fails, raises DeserializationError. + """ + collection = collection or self._default_collection + + value = self._key_value.get(key=key, collection=collection) + if value is not None: + validated = self._validate_model(value=value) + if validated is not None: + return validated + + return default + + @overload + def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: T) -> list[T]: ... 
+ + @overload + def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: None = None) -> list[T | None]: ... + + def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: T | None = None) -> list[T] | list[T | None]: + """Batch get and validate models by keys, preserving order. + + Args: + keys: The list of keys to retrieve. + collection: The collection to use. If not provided, uses the default collection. + default: The default value to return for keys that don't exist or fail validation. + + Returns: + A list of parsed model instances, with default values for missing keys or validation failures. + + Raises: + DeserializationError: If the stored data cannot be validated as the model and the adapter is configured to + raise on validation error. + + Note: + When raise_on_validation_error=False and validation fails for any key, that position in the returned list + will contain the default value (which may be None). The method returns a complete list matching the order + and length of the input keys, with defaults substituted for missing or invalid entries. + """ + collection = collection or self._default_collection + + values: list[dict[str, Any] | None] = self._key_value.get_many(keys=keys, collection=collection) + + result: list[T | None] = [] + for value in values: + if value is None: + result.append(default) + else: + validated = self._validate_model(value=value) + result.append(validated if validated is not None else default) + return result + + def put(self, key: str, value: T, *, collection: str | None = None, ttl: SupportsFloat | None = None) -> None: + """Serialize and store a model. + + Propagates SerializationError if the model cannot be serialized. + """ + collection = collection or self._default_collection + + value_dict: dict[str, Any] = self._serialize_model(value=value) + + self._key_value.put(key=key, value=value_dict, collection=collection, ttl=ttl) + + def put_many( + self, keys: Sequence[str], values: Sequence[T], *, collection: str | None = None, ttl: SupportsFloat | None = None + ) -> None: + """Serialize and store multiple models, preserving order alignment with keys.""" + collection = collection or self._default_collection + + value_dicts: list[dict[str, Any]] = [self._serialize_model(value=value) for value in values] + + self._key_value.put_many(keys=keys, values=value_dicts, collection=collection, ttl=ttl) + + def delete(self, key: str, *, collection: str | None = None) -> bool: + """Delete a model by key. Returns True if a value was deleted, else False.""" + collection = collection or self._default_collection + + return self._key_value.delete(key=key, collection=collection) + + def delete_many(self, keys: Sequence[str], *, collection: str | None = None) -> int: + """Delete multiple models by key. Returns the count of deleted entries.""" + collection = collection or self._default_collection + + return self._key_value.delete_many(keys=keys, collection=collection) + + def ttl(self, key: str, *, collection: str | None = None) -> tuple[T | None, float | None]: + """Get a model and its TTL seconds if present. + + Args: + key: The key to retrieve. + collection: The collection to use. If not provided, uses the default collection. + + Returns: + A tuple of (model, ttl_seconds). Returns (None, None) if the key is missing or validation fails. + + Note: + When validation fails and raise_on_validation_error=False, returns (None, None) even if TTL data exists. 
+ When validation fails and raise_on_validation_error=True, raises DeserializationError. + """ + collection = collection or self._default_collection + + entry: dict[str, Any] | None + ttl_info: float | None + + (entry, ttl_info) = self._key_value.ttl(key=key, collection=collection) + + if entry is None: + return (None, None) + + validated_model = self._validate_model(value=entry) + if validated_model is not None: + return (validated_model, ttl_info) + + return (None, None) + + def ttl_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[tuple[T | None, float | None]]: + """Batch get models with TTLs. Each element is (model|None, ttl_seconds|None).""" + collection = collection or self._default_collection + + entries: list[tuple[dict[str, Any] | None, float | None]] = self._key_value.ttl_many(keys=keys, collection=collection) + + return [(self._validate_model(value=entry) if entry is not None else None, ttl_info) for (entry, ttl_info) in entries] diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/protocols/key_value.py b/key-value/key-value-sync/src/key_value/sync/code_gen/protocols/key_value.py index 2ed49412..d5d9c70d 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/protocols/key_value.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/protocols/key_value.py @@ -9,8 +9,9 @@ class KeyValueProtocol(Protocol): """A subset of KV operations: get/put/delete and TTL variants, including bulk calls. - This protocol defines the minimal contract for key-value store implementations. All methods are - async and may raise exceptions on connection failures, validation errors, or other operational issues. + This protocol defines the minimal contract for key-value store implementations. All methods may + raise exceptions on connection failures, validation errors, or other operational issues. + Implementations should handle backend-specific errors appropriately. """ diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/base.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/base.py index e57cc10c..1c02abda 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/base.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/base.py @@ -252,7 +252,6 @@ def _put_managed_entries( created_at: datetime, expires_at: datetime | None, ) -> None: - """Store multiple managed entries by key in the specified collection. Args: diff --git a/key-value/key-value-sync/tests/code_gen/adapters/test_dataclass.py b/key-value/key-value-sync/tests/code_gen/adapters/test_dataclass.py new file mode 100644 index 00000000..ec618f3d --- /dev/null +++ b/key-value/key-value-sync/tests/code_gen/adapters/test_dataclass.py @@ -0,0 +1,256 @@ +# WARNING: this file is auto-generated by 'build_sync_library.py' +# from the original file 'test_dataclass.py' +# DO NOT CHANGE! Change the original file instead. 
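
+
+# Illustrative usage sketch for the adapter exercised below (not executed; names refer to the
+# fixtures and dataclasses defined in this module):
+#
+#   store = MemoryStore()
+#   adapter = DataclassAdapter[User](key_value=store, dataclass_type=User)
+#   adapter.put(collection="users", key="user:1", value=User(name="John Doe", age=30, email="john.doe@example.com"))
+#   assert adapter.get(collection="users", key="user:1") == User(name="John Doe", age=30, email="john.doe@example.com")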
+from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Any + +import pytest +from inline_snapshot import snapshot +from key_value.shared.errors import DeserializationError + +from key_value.sync.code_gen.adapters.dataclass import DataclassAdapter +from key_value.sync.code_gen.stores.memory.store import MemoryStore + + +@dataclass +class User: + name: str + age: int + email: str + + +@dataclass +class UpdatedUser: + name: str + age: int + email: str + is_admin: bool + + +@dataclass +class Product: + name: str + price: float + quantity: int + + +@dataclass +class Address: + street: str + city: str + zip_code: str + + +@dataclass +class UserWithAddress: + name: str + age: int + address: Address + + +@dataclass +class Order: + created_at: datetime + updated_at: datetime + user: User + product: Product + paid: bool = False + + +FIXED_CREATED_AT: datetime = datetime(year=2021, month=1, day=1, hour=12, minute=0, second=0, tzinfo=timezone.utc) +FIXED_UPDATED_AT: datetime = datetime(year=2021, month=1, day=1, hour=15, minute=0, second=0, tzinfo=timezone.utc) + +SAMPLE_USER: User = User(name="John Doe", email="john.doe@example.com", age=30) +SAMPLE_USER_2: User = User(name="Jane Doe", email="jane.doe@example.com", age=25) +SAMPLE_PRODUCT: Product = Product(name="Widget", price=29.99, quantity=10) +SAMPLE_ADDRESS: Address = Address(street="123 Main St", city="Springfield", zip_code="12345") +SAMPLE_USER_WITH_ADDRESS: UserWithAddress = UserWithAddress(name="John Doe", age=30, address=SAMPLE_ADDRESS) +SAMPLE_ORDER: Order = Order(created_at=FIXED_CREATED_AT, updated_at=FIXED_UPDATED_AT, user=SAMPLE_USER, product=SAMPLE_PRODUCT, paid=False) + +TEST_COLLECTION: str = "test_collection" +TEST_KEY: str = "test_key" +TEST_KEY_2: str = "test_key_2" + + +class TestDataclassAdapter: + @pytest.fixture + def store(self) -> MemoryStore: + return MemoryStore() + + @pytest.fixture + def user_adapter(self, store: MemoryStore) -> DataclassAdapter[User]: + return DataclassAdapter[User](key_value=store, dataclass_type=User) + + @pytest.fixture + def updated_user_adapter(self, store: MemoryStore) -> DataclassAdapter[UpdatedUser]: + return DataclassAdapter[UpdatedUser](key_value=store, dataclass_type=UpdatedUser) + + @pytest.fixture + def product_adapter(self, store: MemoryStore) -> DataclassAdapter[Product]: + return DataclassAdapter[Product](key_value=store, dataclass_type=Product) + + @pytest.fixture + def product_list_adapter(self, store: MemoryStore) -> DataclassAdapter[list[Product]]: + return DataclassAdapter[list[Product]](key_value=store, dataclass_type=list[Product]) + + @pytest.fixture + def user_with_address_adapter(self, store: MemoryStore) -> DataclassAdapter[UserWithAddress]: + return DataclassAdapter[UserWithAddress](key_value=store, dataclass_type=UserWithAddress) + + @pytest.fixture + def order_adapter(self, store: MemoryStore) -> DataclassAdapter[Order]: + return DataclassAdapter[Order](key_value=store, dataclass_type=Order) + + def test_simple_adapter(self, user_adapter: DataclassAdapter[User]): + """Test basic put/get/delete operations with a simple dataclass.""" + user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER) + cached_user: User | None = user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) + assert cached_user == SAMPLE_USER + + assert user_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY) + + assert user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None + + def test_simple_adapter_with_default(self, 
user_adapter: DataclassAdapter[User]): + """Test default value handling.""" + assert user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY, default=SAMPLE_USER) == SAMPLE_USER + + user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER_2) + assert user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY, default=SAMPLE_USER) == SAMPLE_USER_2 + + assert user_adapter.get_many(collection=TEST_COLLECTION, keys=[TEST_KEY, TEST_KEY_2], default=SAMPLE_USER) == snapshot( + [SAMPLE_USER_2, SAMPLE_USER] + ) + + def test_simple_adapter_with_validation_error_ignore( + self, user_adapter: DataclassAdapter[User], updated_user_adapter: DataclassAdapter[UpdatedUser] + ): + """Test that validation errors return None when raise_on_validation_error is False.""" + user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER) + + # UpdatedUser requires is_admin field which doesn't exist in stored User + updated_user = updated_user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) + assert updated_user is None + + def test_simple_adapter_with_validation_error_raise( + self, user_adapter: DataclassAdapter[User], updated_user_adapter: DataclassAdapter[UpdatedUser] + ): + """Test that validation errors raise DeserializationError when raise_on_validation_error is True.""" + user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER) + updated_user_adapter._raise_on_validation_error = True # pyright: ignore[reportPrivateUsage] + with pytest.raises(DeserializationError): + updated_user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) + + def test_nested_dataclass(self, user_with_address_adapter: DataclassAdapter[UserWithAddress]): + """Test that nested dataclasses are properly serialized and deserialized.""" + user_with_address_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER_WITH_ADDRESS) + cached_user: UserWithAddress | None = user_with_address_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) + assert cached_user == SAMPLE_USER_WITH_ADDRESS + assert cached_user is not None + assert cached_user.address.street == "123 Main St" + + def test_complex_adapter(self, order_adapter: DataclassAdapter[Order]): + """Test complex dataclass with nested objects and TTL.""" + order_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_ORDER, ttl=10) + cached_order: Order | None = order_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) + assert cached_order == SAMPLE_ORDER + + assert order_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY) + assert order_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None + + def test_complex_adapter_with_list(self, product_list_adapter: DataclassAdapter[list[Product]], store: MemoryStore): + """Test list dataclass serialization with proper wrapping.""" + product_list_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=[SAMPLE_PRODUCT, SAMPLE_PRODUCT], ttl=10) + cached_products: list[Product] | None = product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) + assert cached_products == [SAMPLE_PRODUCT, SAMPLE_PRODUCT] + + # We need to ensure our memory store doesn't hold an entry with an array + raw_collection = store._cache.get(TEST_COLLECTION) # pyright: ignore[reportPrivateUsage] + assert raw_collection is not None + + raw_entry = raw_collection.get(TEST_KEY) + assert raw_entry is not None + assert isinstance(raw_entry.value, dict) + assert raw_entry.value == snapshot( + {"items": [{"name": "Widget", "price": 29.99, "quantity": 10}, {"name": "Widget", 
"price": 29.99, "quantity": 10}]} + ) + + assert product_list_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY) + assert product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None + + def test_batch_operations(self, user_adapter: DataclassAdapter[User]): + """Test batch put/get/delete operations.""" + keys = [TEST_KEY, TEST_KEY_2] + users = [SAMPLE_USER, SAMPLE_USER_2] + + # Test put_many + user_adapter.put_many(collection=TEST_COLLECTION, keys=keys, values=users) + + # Test get_many + cached_users = user_adapter.get_many(collection=TEST_COLLECTION, keys=keys) + assert cached_users == users + + # Test delete_many + deleted_count = user_adapter.delete_many(collection=TEST_COLLECTION, keys=keys) + assert deleted_count == 2 + + # Verify deletion + cached_users_after_delete = user_adapter.get_many(collection=TEST_COLLECTION, keys=keys) + assert cached_users_after_delete == [None, None] + + def test_ttl_operations(self, user_adapter: DataclassAdapter[User]): + """Test TTL-related operations.""" + # Test single TTL + user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER, ttl=10) + (user, ttl) = user_adapter.ttl(collection=TEST_COLLECTION, key=TEST_KEY) + assert user == SAMPLE_USER + assert ttl is not None + assert ttl > 0 + + # Test ttl_many + user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY_2, value=SAMPLE_USER_2, ttl=20) + ttl_results = user_adapter.ttl_many(collection=TEST_COLLECTION, keys=[TEST_KEY, TEST_KEY_2]) + assert len(ttl_results) == 2 + assert ttl_results[0][0] == SAMPLE_USER + assert ttl_results[1][0] == SAMPLE_USER_2 + + def test_dataclass_validation_on_init(self, store: MemoryStore): + """Test that non-dataclass types are rejected.""" + with pytest.raises(TypeError, match="is not a dataclass"): + DataclassAdapter[str](key_value=store, dataclass_type=str) # type: ignore[type-var] + + def test_default_collection(self, store: MemoryStore): + """Test that default collection is used when not specified.""" + adapter = DataclassAdapter[User](key_value=store, dataclass_type=User, default_collection=TEST_COLLECTION) + + adapter.put(key=TEST_KEY, value=SAMPLE_USER) + cached_user = adapter.get(key=TEST_KEY) + assert cached_user == SAMPLE_USER + + assert adapter.delete(key=TEST_KEY) + + def test_ttl_with_empty_list(self, product_list_adapter: DataclassAdapter[list[Product]]): + """Test that TTL with empty list returns correctly (not None).""" + product_list_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=[], ttl=10) + (value, ttl) = product_list_adapter.ttl(collection=TEST_COLLECTION, key=TEST_KEY) + assert value == [] + assert ttl is not None + assert ttl > 0 + + def test_list_payload_missing_items_returns_none(self, product_list_adapter: DataclassAdapter[list[Product]], store: MemoryStore): + """Test that list payload without 'items' wrapper returns None when raise_on_validation_error is False.""" + # Manually insert malformed payload without the 'items' wrapper + # The payload is a dict but without the expected 'items' key for list models + malformed_payload: dict[str, Any] = {"wrong": []} + store.put(collection=TEST_COLLECTION, key=TEST_KEY, value=malformed_payload) + assert product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None + + def test_list_payload_missing_items_raises(self, product_list_adapter: DataclassAdapter[list[Product]], store: MemoryStore): + """Test that list payload without 'items' wrapper raises DeserializationError when configured.""" + product_list_adapter._raise_on_validation_error 
= True # pyright: ignore[reportPrivateUsage] + # Manually insert malformed payload without the 'items' wrapper + malformed_payload: dict[str, Any] = {"wrong": []} + store.put(collection=TEST_COLLECTION, key=TEST_KEY, value=malformed_payload) + with pytest.raises(DeserializationError, match="missing 'items'"): + product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) diff --git a/key-value/key-value-sync/tests/code_gen/adapters/test_pydantic.py b/key-value/key-value-sync/tests/code_gen/adapters/test_pydantic.py index 3cef2742..1b510fab 100644 --- a/key-value/key-value-sync/tests/code_gen/adapters/test_pydantic.py +++ b/key-value/key-value-sync/tests/code_gen/adapters/test_pydantic.py @@ -94,6 +94,14 @@ def test_simple_adapter_with_default(self, user_adapter: PydanticAdapter[User]): [SAMPLE_USER_2, SAMPLE_USER] ) + def test_simple_adapter_with_list(self, product_list_adapter: PydanticAdapter[list[Product]]): + product_list_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=[SAMPLE_PRODUCT, SAMPLE_PRODUCT]) + cached_products: list[Product] | None = product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) + assert cached_products == [SAMPLE_PRODUCT, SAMPLE_PRODUCT] + + assert product_list_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY) + assert product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None + def test_simple_adapter_with_validation_error_ignore( self, user_adapter: PydanticAdapter[User], updated_user_adapter: PydanticAdapter[UpdatedUser] ):