diff --git a/dataframely/_serialization.py b/dataframely/_serialization.py
index e973b736..4a60fe79 100644
--- a/dataframely/_serialization.py
+++ b/dataframely/_serialization.py
@@ -10,8 +10,6 @@
 import polars as pl
 
-SCHEMA_METADATA_KEY = "dataframely_schema"
-COLLECTION_METADATA_KEY = "dataframely_collection"
 SERIALIZATION_FORMAT_VERSION = "1"
diff --git a/dataframely/_storage/__init__.py b/dataframely/_storage/__init__.py
new file mode 100644
index 00000000..f090c20e
--- /dev/null
+++ b/dataframely/_storage/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) QuantCo 2025-2025
+# SPDX-License-Identifier: BSD-3-Clause
+
+from ._base import StorageBackend
+
+__all__ = [
+    "StorageBackend",
+]
diff --git a/dataframely/_storage/_base.py b/dataframely/_storage/_base.py
new file mode 100644
index 00000000..b6ec6667
--- /dev/null
+++ b/dataframely/_storage/_base.py
@@ -0,0 +1,198 @@
+# Copyright (c) QuantCo 2025-2025
+# SPDX-License-Identifier: BSD-3-Clause
+
+from abc import ABC, abstractmethod
+from collections.abc import Iterable
+from typing import Any
+
+import polars as pl
+
+SerializedSchema = str
+SerializedCollection = str
+SerializedRules = str
+
+
+class StorageBackend(ABC):
+    """Base class for storage backends.
+
+    A storage backend encapsulates a way of serializing and deserializing dataframely
+    data frames, lazy frames, and collections. This base class provides a unified
+    interface for all such use cases.
+
+    The interface is designed to operate on data provided as polars frames and metadata
+    provided as serialized strings. This design is meant to limit the coupling between
+    the Schema/Collection classes and the specifics of how data and metadata are stored.
+    """
+
+    # ----------------------------------- Schemas -------------------------------------
+    @abstractmethod
+    def sink_frame(
+        self, lf: pl.LazyFrame, serialized_schema: SerializedSchema, **kwargs: Any
+    ) -> None:
+        """Stream the contents of a lazy frame and its metadata to the storage backend.
+
+        Args:
+            lf: A frame containing the data to be stored.
+            serialized_schema: String-serialized schema information.
+            kwargs: Additional keyword arguments to pass to the underlying storage
+                implementation.
+        """
+
+    @abstractmethod
+    def write_frame(
+        self, df: pl.DataFrame, serialized_schema: SerializedSchema, **kwargs: Any
+    ) -> None:
+        """Write the contents of a dataframe and its metadata to the storage backend.
+
+        Args:
+            df: A dataframe containing the data to be stored.
+            serialized_schema: String-serialized schema information.
+            kwargs: Additional keyword arguments to pass to the underlying storage
+                implementation.
+        """
+
+    @abstractmethod
+    def scan_frame(self, **kwargs: Any) -> tuple[pl.LazyFrame, SerializedSchema | None]:
+        """Lazily read frame data and metadata from the storage backend.
+
+        Args:
+            kwargs: Keyword arguments to pass to the underlying storage.
+                Refer to the individual implementation to see which keywords
+                are available.
+        Returns:
+            A tuple of the lazy frame data and metadata if available.
+        """
+
+    @abstractmethod
+    def read_frame(self, **kwargs: Any) -> tuple[pl.DataFrame, SerializedSchema | None]:
+        """Eagerly read frame data and metadata from the storage backend.
+
+        Args:
+            kwargs: Keyword arguments to pass to the underlying storage.
+                Refer to the individual implementation to see which keywords
+                are available.
+        Returns:
+            A tuple of the dataframe and metadata if available.
+        """
+
+    # ------------------------------ Collections ---------------------------------------
+    @abstractmethod
+    def sink_collection(
+        self,
+        dfs: dict[str, pl.LazyFrame],
+        serialized_collection: SerializedCollection,
+        serialized_schemas: dict[str, str],
+        **kwargs: Any,
+    ) -> None:
+        """Stream the members of this collection into the storage backend.
+
+        Args:
+            dfs: Dictionary containing the data to be stored.
+            serialized_collection: String-serialized information about the origin Collection.
+            serialized_schemas: String-serialized information about the individual Schemas
+                for each of the member frames. This information is also logically included
+                in the collection metadata, but it is passed separately here to ensure that
+                each member can also be read back as an individual frame.
+        """
+
+    @abstractmethod
+    def write_collection(
+        self,
+        dfs: dict[str, pl.LazyFrame],
+        serialized_collection: SerializedCollection,
+        serialized_schemas: dict[str, str],
+        **kwargs: Any,
+    ) -> None:
+        """Write the members of this collection into the storage backend.
+
+        Args:
+            dfs: Dictionary containing the data to be stored.
+            serialized_collection: String-serialized information about the origin Collection.
+            serialized_schemas: String-serialized information about the individual Schemas
+                for each of the member frames. This information is also logically included
+                in the collection metadata, but it is passed separately here to ensure that
+                each member can also be read back as an individual frame.
+        """
+
+    @abstractmethod
+    def scan_collection(
+        self, members: Iterable[str], **kwargs: Any
+    ) -> tuple[dict[str, pl.LazyFrame], list[SerializedCollection | None]]:
+        """Lazily read all collection members from the storage backend.
+
+        Args:
+            members: Collection member names to read.
+            kwargs: Additional keyword arguments to pass to the underlying storage.
+                Refer to the individual implementation to see which keywords are available.
+        Returns:
+            A tuple of the collection data and metadata if available.
+            Depending on the storage implementation, multiple copies of the metadata
+            may be available, which are returned as a list.
+            It is up to the caller to decide how to handle the presence/absence/consistency
+            of the returned values.
+        """
+
+    @abstractmethod
+    def read_collection(
+        self, members: Iterable[str], **kwargs: Any
+    ) -> tuple[dict[str, pl.LazyFrame], list[SerializedCollection | None]]:
+        """Eagerly read all collection members from the storage backend.
+
+        Args:
+            members: Collection member names to read.
+            kwargs: Additional keyword arguments to pass to the underlying storage.
+                Refer to the individual implementation to see which keywords are available.
+        Returns:
+            A tuple of the collection data and metadata if available.
+            Depending on the storage implementation, multiple copies of the metadata
+            may be available, which are returned as a list.
+            It is up to the caller to decide how to handle the presence/absence/consistency
+            of the returned values.
+        """
+
+    # ------------------------------ Failure Info --------------------------------------
+    @abstractmethod
+    def sink_failure_info(
+        self,
+        lf: pl.LazyFrame,
+        serialized_rules: SerializedRules,
+        serialized_schema: SerializedSchema,
+        **kwargs: Any,
+    ) -> None:
+        """Stream the failure info to the storage backend.
+
+        Args:
+            lf: LazyFrame backing the failure info.
+            serialized_rules: JSON-serialized list of rule column names
+                used for validation.
+            serialized_schema: String-serialized schema information.
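+            kwargs: Additional keyword arguments to pass to the underlying storage
+                implementation.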
+ """ + + @abstractmethod + def write_failure_info( + self, + df: pl.DataFrame, + serialized_rules: SerializedRules, + serialized_schema: SerializedSchema, + **kwargs: Any, + ) -> None: + """Write the failure info to the storage backend. + + Args: + df: DataFrame backing the failure info. + serialized_rules: JSON-serialized list of rule column names + used for validation. + serialized_schema: String-serialized schema information. + """ + + @abstractmethod + def scan_failure_info( + self, **kwargs: Any + ) -> tuple[pl.LazyFrame, SerializedRules, SerializedSchema]: + """Lazily read the failure info from the storage backend.""" + + @abstractmethod + def read_failure_info( + self, **kwargs: Any + ) -> tuple[pl.DataFrame, SerializedRules, SerializedSchema]: + """Read the failure info from the storage backend.""" diff --git a/dataframely/_storage/parquet.py b/dataframely/_storage/parquet.py new file mode 100644 index 00000000..c7dd9be8 --- /dev/null +++ b/dataframely/_storage/parquet.py @@ -0,0 +1,248 @@ +# Copyright (c) QuantCo 2025-2025 +# SPDX-License-Identifier: BSD-3-Clause + +from collections.abc import Iterable +from pathlib import Path +from typing import Any + +import polars as pl + +from ._base import ( + SerializedCollection, + SerializedRules, + SerializedSchema, + StorageBackend, +) + +SCHEMA_METADATA_KEY = "dataframely_schema" +COLLECTION_METADATA_KEY = "dataframely_collection" +RULE_METADATA_KEY = "dataframely_rule_columns" + + +class ParquetStorageBackend(StorageBackend): + """IO manager that stores data and metadata in parquet files on a file system. + + Single frames are stored as individual parquet files + + Collections are stored as directories. + """ + + # ----------------------------------- Schemas ------------------------------------- + def sink_frame( + self, lf: pl.LazyFrame, serialized_schema: SerializedSchema, **kwargs: Any + ) -> None: + file = kwargs.pop("file") + metadata = kwargs.pop("metadata", {}) + lf.sink_parquet( + file, + metadata={**metadata, SCHEMA_METADATA_KEY: serialized_schema}, + **kwargs, + ) + + def write_frame( + self, df: pl.DataFrame, serialized_schema: SerializedSchema, **kwargs: Any + ) -> None: + file = kwargs.pop("file") + metadata = kwargs.pop("metadata", {}) + df.write_parquet( + file, + metadata={**metadata, SCHEMA_METADATA_KEY: serialized_schema}, + **kwargs, + ) + + def scan_frame(self, **kwargs: Any) -> tuple[pl.LazyFrame, SerializedSchema | None]: + source = kwargs.pop("source") + lf = pl.scan_parquet(source, **kwargs) + metadata = _read_serialized_schema(source) + return lf, metadata + + def read_frame(self, **kwargs: Any) -> tuple[pl.DataFrame, SerializedSchema | None]: + source = kwargs.pop("source") + df = pl.read_parquet(source, **kwargs) + metadata = _read_serialized_schema(source) + return df, metadata + + # ------------------------------ Collections --------------------------------------- + def sink_collection( + self, + dfs: dict[str, pl.LazyFrame], + serialized_collection: SerializedCollection, + serialized_schemas: dict[str, str], + **kwargs: Any, + ) -> None: + path = Path(kwargs.pop("directory")) + + # The collection schema is serialized as part of the member parquet metadata + kwargs["metadata"] = kwargs.get("metadata", {}) | { + COLLECTION_METADATA_KEY: serialized_collection + } + + for key, lf in dfs.items(): + destination = ( + path / key if "partition_by" in kwargs else path / f"{key}.parquet" + ) + self.sink_frame( + lf, + serialized_schema=serialized_schemas[key], + file=destination, + **kwargs, + ) + + def 
write_collection( + self, + dfs: dict[str, pl.LazyFrame], + serialized_collection: SerializedCollection, + serialized_schemas: dict[str, str], + **kwargs: Any, + ) -> None: + path = Path(kwargs.pop("directory")) + + # The collection schema is serialized as part of the member parquet metadata + kwargs["metadata"] = kwargs.get("metadata", {}) | { + COLLECTION_METADATA_KEY: serialized_collection + } + + for key, lf in dfs.items(): + destination = ( + path / key if "partition_by" in kwargs else path / f"{key}.parquet" + ) + self.write_frame( + lf.collect(), + serialized_schema=serialized_schemas[key], + file=destination, + **kwargs, + ) + + def scan_collection( + self, members: Iterable[str], **kwargs: Any + ) -> tuple[dict[str, pl.LazyFrame], list[SerializedCollection | None]]: + path = Path(kwargs.pop("directory")) + return self._collection_from_parquet( + path=path, members=members, scan=True, **kwargs + ) + + def read_collection( + self, members: Iterable[str], **kwargs: Any + ) -> tuple[dict[str, pl.LazyFrame], list[SerializedCollection | None]]: + path = Path(kwargs.pop("directory")) + return self._collection_from_parquet( + path=path, members=members, scan=False, **kwargs + ) + + def _collection_from_parquet( + self, path: Path, members: Iterable[str], scan: bool, **kwargs: Any + ) -> tuple[dict[str, pl.LazyFrame], list[SerializedCollection | None]]: + # Utility method encapsulating the logic that is common + # between lazy and eager reads + data = {} + collection_types = [] + + for key in members: + if (source_path := self._member_source_path(path, key)) is not None: + data[key] = ( + pl.scan_parquet(source_path, **kwargs) + if scan + else pl.read_parquet(source_path, **kwargs).lazy() + ) + if source_path.is_file(): + collection_types.append(_read_serialized_collection(source_path)) + else: + for file in source_path.glob("**/*.parquet"): + collection_types.append(_read_serialized_collection(file)) + + # Backward compatibility: If the parquets do not have schema information, + # fall back to looking for schema.json + if not any(collection_types) and (schema_file := path / "schema.json").exists(): + collection_types.append(schema_file.read_text()) + + return data, collection_types + + @classmethod + def _member_source_path(cls, base_path: Path, name: str) -> Path | None: + if (path := base_path / name).exists() and base_path.is_dir(): + # We assume that the member is stored as a hive-partitioned dataset + return path + if (path := base_path / f"{name}.parquet").exists(): + # We assume that the member is stored as a single parquet file + return path + return None + + # ------------------------------ Failure Info -------------------------------------- + def sink_failure_info( + self, + lf: pl.LazyFrame, + serialized_rules: SerializedRules, + serialized_schema: SerializedSchema, + **kwargs: Any, + ) -> None: + self._write_failure_info( + df=lf, + serialized_rules=serialized_rules, + serialized_schema=serialized_schema, + **kwargs, + ) + + def write_failure_info( + self, + df: pl.DataFrame, + serialized_rules: SerializedRules, + serialized_schema: SerializedSchema, + **kwargs: Any, + ) -> None: + self._write_failure_info( + df=df, + serialized_rules=serialized_rules, + serialized_schema=serialized_schema, + **kwargs, + ) + + def _write_failure_info( + self, + df: pl.DataFrame | pl.LazyFrame, + serialized_rules: SerializedRules, + serialized_schema: SerializedSchema, + **kwargs: Any, + ) -> None: + file = kwargs.pop("file") + metadata = kwargs.pop("metadata", {}) + + 
metadata[RULE_METADATA_KEY] = serialized_rules + metadata[SCHEMA_METADATA_KEY] = serialized_schema + + if isinstance(df, pl.DataFrame): + df.write_parquet(file, metadata=metadata, **kwargs) + else: + df.sink_parquet(file, metadata=metadata, **kwargs) + + def scan_failure_info( + self, **kwargs: Any + ) -> tuple[pl.LazyFrame, SerializedRules, SerializedSchema]: + file = kwargs.pop("file") + metadata = pl.read_parquet_metadata(file) + schema_metadata = metadata.get(SCHEMA_METADATA_KEY) + + rule_metadata = metadata.get(RULE_METADATA_KEY) + if schema_metadata is None or rule_metadata is None: + raise ValueError("The parquet file does not contain the required metadata.") + lf = pl.scan_parquet(file, **kwargs) + return lf, rule_metadata, schema_metadata + + def read_failure_info( + self, **kwargs: Any + ) -> tuple[pl.DataFrame, SerializedRules, SerializedSchema]: + lf, rule_metadata, schema_metadata = self.scan_failure_info(**kwargs) + return ( + lf.collect(), + rule_metadata, + schema_metadata, + ) + + +def _read_serialized_collection(path: Path) -> SerializedCollection | None: + meta = pl.read_parquet_metadata(path) + return meta.get(COLLECTION_METADATA_KEY) + + +def _read_serialized_schema(path: Path) -> SerializedSchema | None: + meta = pl.read_parquet_metadata(path) + return meta.get(SCHEMA_METADATA_KEY) diff --git a/dataframely/collection.py b/dataframely/collection.py index 4c81e6a1..5b3b367e 100644 --- a/dataframely/collection.py +++ b/dataframely/collection.py @@ -19,12 +19,13 @@ from ._filter import Filter from ._polars import FrameType from ._serialization import ( - COLLECTION_METADATA_KEY, SERIALIZATION_FORMAT_VERSION, SchemaJSONDecoder, SchemaJSONEncoder, serialization_versions, ) +from ._storage import StorageBackend +from ._storage.parquet import COLLECTION_METADATA_KEY, ParquetStorageBackend from ._typing import LazyFrame, Validation from .exc import ( MemberValidationError, @@ -741,7 +742,7 @@ def write_parquet(self, directory: str | Path, **kwargs: Any) -> None: Attention: This method suffers from the same limitations as :meth:`Schema.serialize`. """ - self._to_parquet(directory, sink=False, **kwargs) + self._write(ParquetStorageBackend(), directory=directory, **kwargs) def sink_parquet(self, directory: str | Path, **kwargs: Any) -> None: """Stream the members of this collection into parquet files in a directory. @@ -761,34 +762,7 @@ def sink_parquet(self, directory: str | Path, **kwargs: Any) -> None: Attention: This method suffers from the same limitations as :meth:`Schema.serialize`. 
""" - self._to_parquet(directory, sink=True, **kwargs) - - def _to_parquet(self, directory: str | Path, *, sink: bool, **kwargs: Any) -> None: - path = Path(directory) if isinstance(directory, str) else directory - path.mkdir(parents=True, exist_ok=True) - - # The collection schema is serialized as part of the member parquet metadata - kwargs["metadata"] = kwargs.get("metadata", {}) | { - COLLECTION_METADATA_KEY: self.serialize() - } - - member_schemas = self.member_schemas() - for key, lf in self.to_dict().items(): - destination = ( - path / key if "partition_by" in kwargs else path / f"{key}.parquet" - ) - if sink: - member_schemas[key].sink_parquet( - lf, # type: ignore - destination, - **kwargs, - ) - else: - member_schemas[key].write_parquet( - lf.collect(), # type: ignore - destination, - **kwargs, - ) + self._sink(ParquetStorageBackend(), directory=directory, **kwargs) @classmethod def read_parquet( @@ -845,14 +819,13 @@ def read_parquet( Be aware that this method suffers from the same limitations as :meth:`serialize`. """ - path = Path(directory) - data, collection_type = cls._from_parquet(path, scan=False, **kwargs) - if not cls._requires_validation_for_reading_parquets( - path, collection_type, validation - ): - cls._validate_input_keys(data) - return cls._init(data) - return cls.validate(data, cast=True) + return cls._read( + backend=ParquetStorageBackend(), + validation=validation, + directory=directory, + lazy=False, + **kwargs, + ) @classmethod def scan_parquet( @@ -912,59 +885,71 @@ def scan_parquet( Be aware that this method suffers from the same limitations as :meth:`serialize`. """ - path = Path(directory) - data, collection_type = cls._from_parquet(path, scan=True, **kwargs) - if not cls._requires_validation_for_reading_parquets( - path, collection_type, validation - ): - cls._validate_input_keys(data) - return cls._init(data) - return cls.validate(data, cast=True) + return cls._read( + backend=ParquetStorageBackend(), + validation=validation, + directory=directory, + lazy=True, + **kwargs, + ) - @classmethod - def _from_parquet( - cls, path: Path, scan: bool, **kwargs: Any - ) -> tuple[dict[str, pl.LazyFrame], type[Collection] | None]: - data = {} - collection_types = set() - for key in cls.members(): - if (source_path := cls._member_source_path(path, key)) is not None: - data[key] = ( - pl.scan_parquet(source_path, **kwargs) - if scan - else pl.read_parquet(source_path, **kwargs).lazy() - ) - if source_path.is_file(): - collection_types.add(read_parquet_metadata_collection(source_path)) - else: - for file in source_path.glob("**/*.parquet"): - collection_types.add(read_parquet_metadata_collection(file)) - collection_type = _reconcile_collection_types(collection_types) + # -------------------------------- Storage --------------------------------------- # + + def _write( + self, backend: StorageBackend, directory: Path | str, **kwargs: Any + ) -> None: + # Utility method encapsulating the interaction with the StorageBackend + + backend.write_collection( + self.to_dict(), + serialized_collection=self.serialize(), + serialized_schemas={ + key: schema.serialize() for key, schema in self.member_schemas().items() + }, + directory=directory, + **kwargs, + ) - # Backward compatibility: If the parquets do not have schema information, - # fall back to looking for schema.json - if (collection_type is None) and (schema_file := path / "schema.json").exists(): - try: - collection_type = deserialize_collection(schema_file.read_text()) - except (JSONDecodeError, plexc.ComputeError): - 
pass + def _sink( + self, backend: StorageBackend, directory: Path | str, **kwargs: Any + ) -> None: + # Utility method encapsulating the interaction with the StorageBackend - return data, collection_type + backend.sink_collection( + self.to_dict(), + serialized_collection=self.serialize(), + serialized_schemas={ + key: schema.serialize() for key, schema in self.member_schemas().items() + }, + directory=directory, + **kwargs, + ) @classmethod - def _member_source_path(cls, base_path: Path, name: str) -> Path | None: - if (path := base_path / name).exists() and base_path.is_dir(): - # We assume that the member is stored as a hive-partitioned dataset - return path - if (path := base_path / f"{name}.parquet").exists(): - # We assume that the member is stored as a single parquet file - return path - return None + def _read( + cls, backend: StorageBackend, validation: Validation, lazy: bool, **kwargs: Any + ) -> Self: + # Utility method encapsulating the interaction with the StorageBackend + + if lazy: + data, serialized_collection_types = backend.scan_collection( + members=cls.member_schemas().keys(), **kwargs + ) + else: + data, serialized_collection_types = backend.read_collection( + members=cls.member_schemas().keys(), **kwargs + ) + + collection_types = _deserialize_types(serialized_collection_types) + collection_type = _reconcile_collection_types(collection_types) + + if cls._requires_validation_for_reading_parquets(collection_type, validation): + return cls.validate(data, cast=True) + return cls.cast(data) @classmethod def _requires_validation_for_reading_parquets( cls, - directory: Path, collection_type: type[Collection] | None, validation: Validation, ) -> bool: @@ -983,12 +968,10 @@ def _requires_validation_for_reading_parquets( ) if validation == "forbid": raise ValidationRequiredError( - f"Cannot read collection from '{directory!r}' without validation: {msg}." + f"Cannot read collection without validation: {msg}." ) if validation == "warn": - warnings.warn( - f"Reading parquet file from '{directory!r}' requires validation: {msg}." 
- ) + warnings.warn(f"Reading parquet file requires validation: {msg}.") return True # ----------------------------------- UTILITIES ---------------------------------- # @@ -1102,6 +1085,23 @@ def _extract_keys_if_exist( return {key: data[key] for key in keys if key in data} +def _deserialize_types( + serialized_collection_types: Iterable[str | None], +) -> list[type[Collection]]: + collection_types = [] + collection_type: type[Collection] | None = None + for t in serialized_collection_types: + if t is None: + continue + try: + collection_type = deserialize_collection(t) + collection_types.append(collection_type) + except (JSONDecodeError, plexc.ComputeError): + pass + + return collection_types + + def _reconcile_collection_types( collection_types: Iterable[type[Collection] | None], ) -> type[Collection] | None: diff --git a/dataframely/failure.py b/dataframely/failure.py index 4b6f31cf..5198f470 100644 --- a/dataframely/failure.py +++ b/dataframely/failure.py @@ -13,12 +13,12 @@ from dataframely._base_schema import BaseSchema -from ._serialization import SCHEMA_METADATA_KEY +from ._storage import StorageBackend +from ._storage.parquet import ParquetStorageBackend if TYPE_CHECKING: # pragma: no cover from .schema import Schema -RULE_METADATA_KEY = "dataframely_rule_columns" UNKNOWN_SCHEMA_NAME = "__DATAFRAMELY_UNKNOWN__" S = TypeVar("S", bound=BaseSchema) @@ -98,8 +98,7 @@ def write_parquet(self, file: str | Path | IO[bytes], **kwargs: Any) -> None: Be aware that this method suffers from the same limitations as :meth:`Schema.serialize`. """ - metadata, kwargs = self._build_metadata(**kwargs) - self._df.write_parquet(file, metadata=metadata, **kwargs) + self._write(ParquetStorageBackend(), file=file, **kwargs) def sink_parquet( self, file: str | Path | IO[bytes] | PartitioningScheme, **kwargs: Any @@ -118,16 +117,7 @@ def sink_parquet( Be aware that this method suffers from the same limitations as :meth:`Schema.serialize`. 
""" - metadata, kwargs = self._build_metadata(**kwargs) - self._lf.sink_parquet(file, metadata=metadata, **kwargs) - - def _build_metadata( - self, **kwargs: dict[str, Any] - ) -> tuple[dict[str, Any], dict[str, Any]]: - metadata = kwargs.pop("metadata", {}) - metadata[RULE_METADATA_KEY] = json.dumps(self._rule_columns) - metadata[SCHEMA_METADATA_KEY] = self.schema.serialize() - return metadata, kwargs + self._sink(ParquetStorageBackend(), file=file, **kwargs) @classmethod def read_parquet( @@ -150,7 +140,9 @@ def read_parquet( Be aware that this method suffers from the same limitations as :meth:`Schema.serialize` """ - return cls._from_parquet(source, scan=False, **kwargs) + return cls._read( + backend=ParquetStorageBackend(), file=source, lazy=False, **kwargs + ) @classmethod def scan_parquet( @@ -171,32 +163,73 @@ def scan_parquet( Be aware that this method suffers from the same limitations as :meth:`Schema.serialize` """ - return cls._from_parquet(source, scan=True, **kwargs) + return cls._read( + backend=ParquetStorageBackend(), file=source, lazy=True, **kwargs + ) + + # -------------------------------- Storage --------------------------------------- # + + def _sink( + self, + backend: StorageBackend, + file: str | Path | IO[bytes] | PartitioningScheme, + **kwargs: Any, + ) -> None: + # Utility method encapsulating the interaction with the StorageBackend + + backend.sink_failure_info( + lf=self._lf, + serialized_rules=json.dumps(self._rule_columns), + serialized_schema=self.schema.serialize(), + file=file, + **kwargs, + ) + + def _write( + self, + backend: StorageBackend, + file: str | Path | IO[bytes] | PartitioningScheme, + **kwargs: Any, + ) -> None: + # Utility method encapsulating the interaction with the StorageBackend + + backend.write_failure_info( + df=self._df, + serialized_rules=json.dumps(self._rule_columns), + serialized_schema=self.schema.serialize(), + file=file, + **kwargs, + ) @classmethod - def _from_parquet( - cls, source: str | Path | IO[bytes], scan: bool, **kwargs: Any + def _read( + cls, + backend: StorageBackend, + file: str | Path | IO[bytes] | PartitioningScheme, + lazy: bool, + **kwargs: Any, ) -> FailureInfo[Schema]: - from .schema import Schema, deserialize_schema + # Utility method encapsulating the interaction with the StorageBackend - metadata = pl.read_parquet_metadata(source) - schema_metadata = metadata.get(SCHEMA_METADATA_KEY) - rule_metadata = metadata.get(RULE_METADATA_KEY) - if schema_metadata is None or rule_metadata is None: - raise ValueError("The parquet file does not contain the required metadata.") + from .schema import Schema, deserialize_schema - lf = ( - pl.scan_parquet(source, **kwargs) - if scan - else pl.read_parquet(source, **kwargs).lazy() - ) - failure_schema = deserialize_schema(schema_metadata, strict=False) or type( + if lazy: + lf, serialized_rules, serialized_schema = backend.scan_failure_info( + file=file, **kwargs + ) + else: + df, serialized_rules, serialized_schema = backend.read_failure_info( + file=file, **kwargs + ) + lf = df.lazy() + + schema = deserialize_schema(serialized_schema, strict=False) or type( UNKNOWN_SCHEMA_NAME, (Schema,), {} ) return FailureInfo( lf, - json.loads(rule_metadata), - schema=failure_schema, + json.loads(serialized_rules), + schema=schema, ) diff --git a/dataframely/schema.py b/dataframely/schema.py index f012560e..42c023f8 100644 --- a/dataframely/schema.py +++ b/dataframely/schema.py @@ -21,12 +21,13 @@ from ._compat import pa, sa from ._rule import Rule, rule_from_dict, 
with_evaluation_rules from ._serialization import ( - SCHEMA_METADATA_KEY, SERIALIZATION_FORMAT_VERSION, SchemaJSONDecoder, SchemaJSONEncoder, serialization_versions, ) +from ._storage import StorageBackend +from ._storage.parquet import SCHEMA_METADATA_KEY, ParquetStorageBackend from ._typing import DataFrame, LazyFrame, Validation from ._validation import DtypeCasting, validate_columns, validate_dtypes from .columns import Column, column_from_dict @@ -697,10 +698,7 @@ def write_parquet( Be aware that this method suffers from the same limitations as :meth:`serialize`. """ - metadata = kwargs.pop("metadata", {}) - df.write_parquet( - file, metadata={**metadata, SCHEMA_METADATA_KEY: cls.serialize()}, **kwargs - ) + cls._write(df=df, backend=ParquetStorageBackend(), file=file, **kwargs) @classmethod def sink_parquet( @@ -728,10 +726,7 @@ def sink_parquet( Be aware that this method suffers from the same limitations as :meth:`serialize`. """ - metadata = kwargs.pop("metadata", {}) - lf.sink_parquet( - file, metadata={**metadata, SCHEMA_METADATA_KEY: cls.serialize()}, **kwargs - ) + cls._sink(lf=lf, backend=ParquetStorageBackend(), file=file, **kwargs) @classmethod def read_parquet( @@ -780,9 +775,13 @@ def read_parquet( Be aware that this method suffers from the same limitations as :meth:`serialize`. """ - if not cls._requires_validation_for_reading_parquet(source, validation): - return pl.read_parquet(source, **kwargs) # type: ignore - return cls.validate(pl.read_parquet(source, **kwargs), cast=True) + return cls._read( + ParquetStorageBackend(), + validation=validation, + lazy=False, + source=source, + **kwargs, + ) @classmethod def scan_parquet( @@ -836,13 +835,20 @@ def scan_parquet( Be aware that this method suffers from the same limitations as :meth:`serialize`. """ - if not cls._requires_validation_for_reading_parquet(source, validation): - return pl.scan_parquet(source, **kwargs) # type: ignore - return cls.validate(pl.read_parquet(source, **kwargs), cast=True).lazy() + return cls._read( + ParquetStorageBackend(), + validation=validation, + lazy=True, + source=source, + **kwargs, + ) @classmethod def _requires_validation_for_reading_parquet( - cls, source: FileSource, validation: Validation + cls, + deserialized_schema: type[Schema] | None, + validation: Validation, + source: str, ) -> bool: if validation == "skip": return False @@ -850,20 +856,16 @@ def _requires_validation_for_reading_parquet( # First, we check whether the source provides the dataframely schema. If it # does, we check whether it matches this schema. If it does, we assume that the # data adheres to the schema and we do not need to run validation. - serialized_schema = ( - read_parquet_metadata_schema(source) - if not isinstance(source, list) - else None - ) - if serialized_schema is not None: - if cls.matches(serialized_schema): + + if deserialized_schema is not None: + if cls.matches(deserialized_schema): return False # Otherwise, we definitely need to run validation. However, we emit different # information to the user depending on the value of `validate`. 
msg = ( "current schema does not match stored schema" - if serialized_schema is not None + if deserialized_schema is not None else "no schema to check validity can be read from the source" ) if validation == "forbid": @@ -876,6 +878,69 @@ def _requires_validation_for_reading_parquet( ) return True + # --------------------------------- Storage -------------------------------------- # + + @classmethod + def _write(cls, df: pl.DataFrame, backend: StorageBackend, **kwargs: Any) -> None: + backend.write_frame(df=df, serialized_schema=cls.serialize(), **kwargs) + + @classmethod + def _sink(cls, lf: pl.LazyFrame, backend: StorageBackend, **kwargs: Any) -> None: + backend.sink_frame(lf=lf, serialized_schema=cls.serialize(), **kwargs) + + @overload + @classmethod + def _read( + cls, + backend: StorageBackend, + validation: Validation, + lazy: Literal[True], + **kwargs: Any, + ) -> LazyFrame[Self]: ... + + @overload + @classmethod + def _read( + cls, + backend: StorageBackend, + validation: Validation, + lazy: Literal[False], + **kwargs: Any, + ) -> DataFrame[Self]: ... + + @classmethod + def _read( + cls, backend: StorageBackend, validation: Validation, lazy: bool, **kwargs: Any + ) -> LazyFrame[Self] | DataFrame[Self]: + source = kwargs.pop("source") + + # Load + if lazy: + lf, serialized_schema = backend.scan_frame(source=source) + else: + df, serialized_schema = backend.read_frame(source=source) + lf = df.lazy() + + deserialized_schema = ( + deserialize_schema(serialized_schema) if serialized_schema else None + ) + + # Smart validation + if cls._requires_validation_for_reading_parquet( + deserialized_schema, validation, source=str(source) + ): + validated = cls.validate(lf, cast=True) + if lazy: + return validated.lazy() + else: + return validated + + casted = cls.cast(lf) + if lazy: + return casted + else: + return casted.collect() + # ----------------------------- THIRD-PARTY PACKAGES ----------------------------- # @classmethod @@ -958,6 +1023,7 @@ def read_parquet_metadata_schema( is found or the deserialization fails. 
""" metadata = pl.read_parquet_metadata(source) + if (schema_metadata := metadata.get(SCHEMA_METADATA_KEY)) is not None: return deserialize_schema(schema_metadata, strict=False) return None diff --git a/tests/collection/test_read_write_parquet.py b/tests/collection/test_read_write_parquet.py index eedc9626..c8c7985b 100644 --- a/tests/collection/test_read_write_parquet.py +++ b/tests/collection/test_read_write_parquet.py @@ -12,7 +12,7 @@ from polars.testing import assert_frame_equal import dataframely as dy -from dataframely._serialization import COLLECTION_METADATA_KEY +from dataframely._storage.parquet import COLLECTION_METADATA_KEY from dataframely.collection import _reconcile_collection_types from dataframely.exc import ValidationRequiredError from dataframely.testing import create_collection, create_schema @@ -371,6 +371,7 @@ class MyCollection2(dy.Collection): ([MyCollection], MyCollection), # One missing type, cannot be sure ([MyCollection, None], None), + ([None, MyCollection], None), # Inconsistent types, treat like no information available ([MyCollection, MyCollection2], None), ], @@ -384,11 +385,15 @@ def test_reconcile_collection_types( # ---------------------------------- MANUAL METADATA --------------------------------- # -def test_read_invalid_parquet_metadata_collection(tmp_path: Path) -> None: +@pytest.mark.parametrize("metadata", [None, {COLLECTION_METADATA_KEY: "invalid"}]) +def test_read_invalid_parquet_metadata_collection( + tmp_path: Path, metadata: dict | None +) -> None: # Arrange df = pl.DataFrame({"a": [1, 2, 3]}) df.write_parquet( - tmp_path / "df.parquet", metadata={COLLECTION_METADATA_KEY: "invalid"} + tmp_path / "df.parquet", + metadata=metadata, ) # Act diff --git a/tests/schema/test_read_write_parquet.py b/tests/schema/test_read_write_parquet.py index 2e944cb0..7095c5b0 100644 --- a/tests/schema/test_read_write_parquet.py +++ b/tests/schema/test_read_write_parquet.py @@ -10,7 +10,7 @@ from polars.testing import assert_frame_equal import dataframely as dy -from dataframely._serialization import SCHEMA_METADATA_KEY +from dataframely._storage.parquet import SCHEMA_METADATA_KEY from dataframely.exc import ValidationRequiredError from dataframely.testing import create_schema @@ -222,10 +222,13 @@ def test_read_write_parquet_validation_skip_invalid_schema( # ---------------------------------- MANUAL METADATA --------------------------------- # -def test_read_invalid_parquet_metadata_schema(tmp_path: Path) -> None: +@pytest.mark.parametrize("metadata", [{SCHEMA_METADATA_KEY: "invalid"}, None]) +def test_read_invalid_parquet_metadata_schema( + tmp_path: Path, metadata: dict | None +) -> None: # Arrange df = pl.DataFrame({"a": [1, 2, 3]}) - df.write_parquet(tmp_path / "df.parquet", metadata={SCHEMA_METADATA_KEY: "invalid"}) + df.write_parquet(tmp_path / "df.parquet", metadata=metadata) # Act schema = dy.read_parquet_metadata_schema(tmp_path / "df.parquet") diff --git a/tests/test_failure_info.py b/tests/test_failure_info.py index 5deed699..bfeaa27b 100644 --- a/tests/test_failure_info.py +++ b/tests/test_failure_info.py @@ -9,8 +9,8 @@ from polars.testing import assert_frame_equal import dataframely as dy -from dataframely._serialization import SCHEMA_METADATA_KEY -from dataframely.failure import RULE_METADATA_KEY, UNKNOWN_SCHEMA_NAME, FailureInfo +from dataframely._storage.parquet import RULE_METADATA_KEY, SCHEMA_METADATA_KEY +from dataframely.failure import UNKNOWN_SCHEMA_NAME, FailureInfo class MySchema(dy.Schema):