fsspec/implementations/reference.py (34 additions & 8 deletions)
@@ -7,7 +7,7 @@
 import os
 from itertools import chain
 from functools import lru_cache
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal

 import fsspec.core

@@ -104,7 +104,13 @@ def pd(self):
         return pd

     def __init__(
-        self, root, fs=None, out_root=None, cache_size=128, categorical_threshold=10
+        self,
+        root,
+        fs=None,
+        out_root=None,
+        cache_size=128,
+        categorical_threshold=10,
+        engine: Literal["fastparquet", "pyarrow"] = "fastparquet",
     ):
         """

@@ -126,11 +132,20 @@ def __init__(
             Encode urls as pandas.Categorical to reduce memory footprint if the ratio
             of the number of unique urls to total number of refs for each variable
             is greater than or equal to this number. (default 10)
+        engine: Literal["fastparquet", "pyarrow"]
+            Engine to use when reading and writing the parquet reference files. (default "fastparquet")
         """
+
+        from importlib.util import find_spec
+
+        if engine == "pyarrow" and find_spec("pyarrow") is None:
+            raise ImportError("engine choice `pyarrow` is not installed.")
+
         self.root = root
         self.chunk_sizes = {}
         self.out_root = out_root or self.root
         self.cat_thresh = categorical_threshold
+        self.engine = engine
         self.cache_size = cache_size
         self.url = self.root + "/{field}/refs.{record}.parq"
         # TODO: derive fs from `root`
@@ -158,7 +173,7 @@ def open_refs(field, record):
             """cached parquet file loader"""
             path = self.url.format(field=field, record=record)
             data = io.BytesIO(self.fs.cat_file(path))
-            df = self.pd.read_parquet(data, engine="fastparquet")
+            df = self.pd.read_parquet(data, engine=self.engine)
             refs = {c: df[c].to_numpy() for c in df.columns}
             return refs

@@ -463,18 +478,28 @@ def write(self, field, record, base_url=None, storage_options=None):

         fn = f"{base_url or self.out_root}/{field}/refs.{record}.parq"
         self.fs.mkdirs(f"{base_url or self.out_root}/{field}", exist_ok=True)
+
+        if self.engine == "pyarrow":
+            df_backend_kwargs = {}
+        elif self.engine == "fastparquet":
+            df_backend_kwargs = {
+                "stats": False,
+                "object_encoding": object_encoding,
+                "has_nulls": has_nulls,
+            }
+        else:
+            raise NotImplementedError(f"{self.engine} not supported")
+
         df.to_parquet(
             fn,
-            engine="fastparquet",
+            engine=self.engine,
             storage_options=storage_options
             or getattr(self.fs, "storage_options", None),
             compression="zstd",
             index=False,
-            stats=False,
-            object_encoding=object_encoding,
-            has_nulls=has_nulls,
-            # **kwargs,
+            **df_backend_kwargs,
         )

         partition.clear()
         self._items.pop((field, record))

@@ -486,6 +511,7 @@ def flush(self, base_url=None, storage_options=None):
         base_url: str
             Location of the output
         """
+
        # write what we have so far and clear sub chunks
        for thing in list(self._items):
            if isinstance(thing, tuple):
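Reviewer note: the keyword split in `write()` is needed because pandas forwards unrecognized `to_parquet` keywords to the underlying engine, and `stats`, `object_encoding` and `has_nulls` are fastparquet-only options that pyarrow's writer would reject. A standalone sketch of the same dispatch, with an illustrative dataframe and file name (the `object_encoding` and `has_nulls` values here are placeholders, not the ones computed in `write()`):

```python
import pandas as pd


def write_refs(df: pd.DataFrame, fn: str, engine: str = "fastparquet") -> None:
    # fastparquet-only keywords: pandas passes them straight through to
    # fastparquet.write(); pyarrow's writer does not accept them.
    if engine == "fastparquet":
        backend_kwargs = {"stats": False, "object_encoding": "utf8", "has_nulls": True}
    elif engine == "pyarrow":
        backend_kwargs = {}
    else:
        raise NotImplementedError(f"{engine} not supported")

    df.to_parquet(fn, engine=engine, compression="zstd", index=False, **backend_kwargs)


# Illustrative reference table: one row per chunk key.
refs = pd.DataFrame({"path": ["a.bin", "b.bin"], "offset": [0, 8], "size": [8, 8]})
write_refs(refs, "refs.0.parq")
```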
fsspec/implementations/tests/test_reference.py (11 additions & 3 deletions)
@@ -761,25 +761,33 @@ def test_append_parquet(lazy_refs, m):
     assert lazy2["data/1"] == b"Adata"


-def test_deep_parq(m):
+@pytest.mark.parametrize("engine", ["fastparquet", "pyarrow"])
+def test_deep_parq(m, engine):
     pytest.importorskip("kerchunk")
     zarr = pytest.importorskip("zarr")

     lz = fsspec.implementations.reference.LazyReferenceMapper.create(
-        "memory://out.parq", fs=m
+        "memory://out.parq",
+        fs=m,
+        engine=engine,
     )
     g = zarr.open_group(lz, mode="w")

     g2 = g.create_group("instant")
     g2.create_dataset(name="one", data=[1, 2, 3])
     lz.flush()

-    lz = fsspec.implementations.reference.LazyReferenceMapper("memory://out.parq", fs=m)
+    lz = fsspec.implementations.reference.LazyReferenceMapper(
+        "memory://out.parq", fs=m, engine=engine
+    )
     g = zarr.open_group(lz)
     assert g.instant.one[:].tolist() == [1, 2, 3]
     assert sorted(_["name"] for _ in lz.ls("")) == [".zgroup", ".zmetadata", "instant"]
     assert sorted(_["name"] for _ in lz.ls("instant")) == [
         "instant/.zgroup",
         "instant/one",
     ]

     assert sorted(_["name"] for _ in lz.ls("instant/one")) == [
         "instant/one/.zarray",
         "instant/one/0",
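End to end, the new parameter is exercised exactly as in the parametrized test above. A minimal usage sketch against an in-memory filesystem (assumes kerchunk, zarr, and the chosen engine are installed):

```python
import fsspec
import zarr

from fsspec.implementations.reference import LazyReferenceMapper

m = fsspec.filesystem("memory")

# Write: build a parquet-backed reference store with the pyarrow engine.
lz = LazyReferenceMapper.create("memory://out.parq", fs=m, engine="pyarrow")
g = zarr.open_group(lz, mode="w")
g.create_group("instant").create_dataset(name="one", data=[1, 2, 3])
lz.flush()

# Read: open_refs() now deserializes each refs.{record}.parq with
# pandas.read_parquet(..., engine="pyarrow").
lz = LazyReferenceMapper("memory://out.parq", fs=m, engine="pyarrow")
assert zarr.open_group(lz).instant.one[:].tolist() == [1, 2, 3]
```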