Merged
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
@@ -56,6 +56,8 @@ jobs:
uses: prefix-dev/setup-pixi@194d461b21b6c5717c722ffc597fa91ed2ff29fa # v0.9.1
with:
environments: ${{ matrix.environment }}
+# FIXME: Remove when `s3_server` fixture does not start a process anymore
+post-cleanup: ${{ matrix.os != 'windows-latest' }}
- name: Install repository
run: pixi run -e ${{ matrix.environment }} postinstall
- name: Run pytest
11 changes: 7 additions & 4 deletions dataframely/_storage/delta.py
@@ -7,6 +7,7 @@
from typing import Any

import polars as pl
+from fsspec import AbstractFileSystem, url_to_fs

from dataframely._compat import deltalake

@@ -82,7 +83,8 @@ def write_collection(
serialized_schemas: dict[str, str],
**kwargs: Any,
) -> None:
-uri = Path(kwargs.pop("target"))
+uri = kwargs.pop("target")
+fs: AbstractFileSystem = url_to_fs(uri)[0]

# The collection schema is serialized as part of the member parquet metadata
kwargs["metadata"] = kwargs.get("metadata", {}) | {
@@ -93,19 +95,20 @@
self.write_frame(
lf.collect(),
serialized_schema=serialized_schemas[key],
-target=uri / key,
+target=fs.sep.join([uri, key]),
**kwargs,
)

def scan_collection(
self, members: Iterable[str], **kwargs: Any
) -> tuple[dict[str, pl.LazyFrame], list[SerializedCollection | None]]:
-uri = Path(kwargs.pop("source"))
+uri = kwargs.pop("source")
+fs: AbstractFileSystem = url_to_fs(uri)[0]

data = {}
collection_types = []
for key in members:
-member_uri = uri / key
+member_uri = fs.sep.join([uri, key])
if not deltalake.DeltaTable.is_deltatable(str(member_uri)):
continue
table = _to_delta_table(member_uri)
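The core of this change is swapping `pathlib.Path` for `fsspec`: `Path` arithmetic silently corrupts remote URIs, while `url_to_fs` resolves a URI to a filesystem object whose separator can join path segments safely. A minimal sketch of the difference (resolving the illustrative `s3://` URI additionally requires the `s3fs` package):

from pathlib import Path

from fsspec import url_to_fs

uri = "s3://bucket/collection"

# pathlib collapses the double slash in the scheme, corrupting the URI:
print(Path(uri) / "member")  # s3:/bucket/collection/member

# url_to_fs returns a (filesystem, path) pair for any registered protocol;
# joining with the filesystem's separator leaves the scheme intact.
fs, _ = url_to_fs(uri)  # needs s3fs installed for the s3:// protocol
print(fs.sep.join([uri, "member"]))  # s3://bucket/collection/member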
67 changes: 46 additions & 21 deletions dataframely/_storage/parquet.py
@@ -2,10 +2,10 @@
# SPDX-License-Identifier: BSD-3-Clause

from collections.abc import Iterable
-from pathlib import Path
from typing import Any

import polars as pl
+from fsspec import AbstractFileSystem, url_to_fs

from ._base import (
SerializedCollection,
@@ -61,23 +61,27 @@ def read_frame(self, **kwargs: Any) -> tuple[pl.DataFrame, SerializedSchema | None]:
return df, metadata

# ------------------------------ Collections ---------------------------------------

def sink_collection(
self,
dfs: dict[str, pl.LazyFrame],
serialized_collection: SerializedCollection,
serialized_schemas: dict[str, str],
**kwargs: Any,
) -> None:
-path = Path(kwargs.pop("directory"))
+path = str(kwargs.pop("directory"))

# The collection schema is serialized as part of the member parquet metadata
kwargs["metadata"] = kwargs.get("metadata", {}) | {
COLLECTION_METADATA_KEY: serialized_collection
}

+fs: AbstractFileSystem = url_to_fs(path)[0]
for key, lf in dfs.items():
destination = (
-path / key if "partition_by" in kwargs else path / f"{key}.parquet"
+fs.sep.join([path, key])
+if "partition_by" in kwargs
+else fs.sep.join([path, f"{key}.parquet"])
)
self.sink_frame(
lf,
@@ -93,16 +97,19 @@ def write_collection(
serialized_schemas: dict[str, str],
**kwargs: Any,
) -> None:
-path = Path(kwargs.pop("directory"))
+path = str(kwargs.pop("directory"))

# The collection schema is serialized as part of the member parquet metadata
kwargs["metadata"] = kwargs.get("metadata", {}) | {
COLLECTION_METADATA_KEY: serialized_collection
}

+fs: AbstractFileSystem = url_to_fs(path)[0]
for key, lf in dfs.items():
destination = (
-path / key if "partition_by" in kwargs else path / f"{key}.parquet"
+fs.sep.join([path, key])
+if "partition_by" in kwargs
+else fs.sep.join([path, f"{key}.parquet"])
)
self.write_frame(
lf.collect(),
@@ -114,53 +121,71 @@ def write_collection(
def scan_collection(
self, members: Iterable[str], **kwargs: Any
) -> tuple[dict[str, pl.LazyFrame], list[SerializedCollection | None]]:
-path = Path(kwargs.pop("directory"))
+path = str(kwargs.pop("directory"))
return self._collection_from_parquet(
path=path, members=members, scan=True, **kwargs
)

def read_collection(
self, members: Iterable[str], **kwargs: Any
) -> tuple[dict[str, pl.LazyFrame], list[SerializedCollection | None]]:
-path = Path(kwargs.pop("directory"))
+path = str(kwargs.pop("directory"))
return self._collection_from_parquet(
path=path, members=members, scan=False, **kwargs
)

def _collection_from_parquet(
-self, path: Path, members: Iterable[str], scan: bool, **kwargs: Any
+self, path: str, members: Iterable[str], scan: bool, **kwargs: Any
) -> tuple[dict[str, pl.LazyFrame], list[SerializedCollection | None]]:
# Utility method encapsulating the logic that is common
# between lazy and eager reads
data = {}
collection_types = []

+fs: AbstractFileSystem = url_to_fs(path)[0]
for key in members:
-if (source_path := self._member_source_path(path, key)) is not None:
+if (source_path := self._member_source_path(path, fs, key)) is not None:
+is_file = fs.isfile(source_path)
+scan_path = source_path if is_file else fs.sep.join([source_path, ""])
data[key] = (
-pl.scan_parquet(source_path, **kwargs)
+pl.scan_parquet(scan_path, **kwargs)
if scan
-else pl.read_parquet(source_path, **kwargs).lazy()
+else pl.read_parquet(scan_path, **kwargs).lazy()
)
-if source_path.is_file():
+if is_file:
collection_types.append(_read_serialized_collection(source_path))
else:
-for file in source_path.glob("**/*.parquet"):
-collection_types.append(_read_serialized_collection(file))
+prefix = (
+""
+if fs.protocol == "file"
+else (
+f"{fs.protocol}://"
+if isinstance(fs.protocol, str)
+else f"{fs.protocol[0]}://"
+)
+)
+for file in fs.glob(fs.sep.join([source_path, "**", "*.parquet"])):
+collection_types.append(
+_read_serialized_collection(f"{prefix}{file}")
+)

# Backward compatibility: If the parquets do not have schema information,
# fall back to looking for schema.json
-if not any(collection_types) and (schema_file := path / "schema.json").exists():
-collection_types.append(schema_file.read_text())
+if not any(collection_types) and fs.exists(
+schema_file := fs.sep.join([path, "schema.json"])
+):
+collection_types.append(fs.read_text(schema_file))

return data, collection_types

@classmethod
-def _member_source_path(cls, base_path: Path, name: str) -> Path | None:
-if (path := base_path / name).exists() and base_path.is_dir():
+def _member_source_path(
+cls, base_path: str, fs: AbstractFileSystem, name: str
+) -> str | None:
+if fs.exists(path := fs.sep.join([base_path, name])) and fs.isdir(base_path):
# We assume that the member is stored as a hive-partitioned dataset
return path
-if (path := base_path / f"{name}.parquet").exists():
+if fs.exists(path := fs.sep.join([base_path, f"{name}.parquet"])):
# We assume that the member is stored as a single parquet file
return path
return None
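Two details above are easy to miss: directory sources get a trailing separator (`fs.sep.join([source_path, ""])`) so `pl.scan_parquet` globs the dataset rather than a single file, and paths returned by `fs.glob` come back without their protocol prefix, so the scheme must be re-attached before `pl.read_parquet_metadata` can resolve them. A standalone sketch of that prefix logic, using the dependency-free `memory://` filesystem (the helper name is illustrative, not part of the library):

from fsspec import AbstractFileSystem, url_to_fs

def with_protocol(fs: AbstractFileSystem, path: str) -> str:
    # Re-attach the URL scheme that fs.glob strips from its results.
    if fs.protocol == "file":
        return path  # local paths are usable as-is
    # fs.protocol is either a string or a tuple of protocol aliases.
    protocol = fs.protocol if isinstance(fs.protocol, str) else fs.protocol[0]
    return f"{protocol}://{path}"

fs, _ = url_to_fs("memory://data")
fs.pipe("/data/member.parquet", b"")  # create a dummy in-memory entry
print([with_protocol(fs, p) for p in fs.glob("/data/*.parquet")])
# ['memory:///data/member.parquet']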
@@ -229,11 +254,11 @@ def scan_failure_info(
return lf, serialized_rules, serialized_schema


-def _read_serialized_collection(path: Path) -> SerializedCollection | None:
+def _read_serialized_collection(path: str) -> SerializedCollection | None:
meta = pl.read_parquet_metadata(path)
return meta.get(COLLECTION_METADATA_KEY)


-def _read_serialized_schema(path: Path) -> SerializedSchema | None:
+def _read_serialized_schema(path: str) -> SerializedSchema | None:
meta = pl.read_parquet_metadata(path)
return meta.get(SCHEMA_METADATA_KEY)
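The trailing `Path` → `str` signature changes work because `pl.read_parquet_metadata` accepts plain string paths and URLs, returning the file-level key-value metadata where the serialized schema and collection JSON live. A minimal round trip of that mechanism (the metadata key below is a stand-in for the library's real `SCHEMA_METADATA_KEY` constant; the `metadata=` parameter requires a recent polars release):

import polars as pl

# Write a parquet file with custom key-value metadata, as write_frame does.
pl.DataFrame({"a": [1, 2]}).write_parquet(
    "member.parquet",
    metadata={"dataframely_schema": '{"name": "MySchema"}'},
)

# read_parquet_metadata returns the key-value pairs without loading the data.
meta = pl.read_parquet_metadata("member.parquet")
print(meta.get("dataframely_schema"))  # {"name": "MySchema"}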