Skip to content

Commit 910ac53

Browse files
committed
Rebase Ryan's work in zarr-developers#2031
1 parent 6811c94 commit 910ac53

File tree

5 files changed

+139
-0
lines changed

5 files changed

+139
-0
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ repos:
3636
- typing_extensions
3737
- universal-pathlib
3838
- obstore>=0.5.1
39+
- pyarrow
3940
# Tests
4041
- pytest
4142
- hypothesis

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ dependencies = [
3737
'numcodecs[crc32c]>=0.14',
3838
'typing_extensions>=4.9',
3939
'donfig>=0.8',
40+
'pyarrow',
4041
]
4142

4243
dynamic = [

src/zarr/codecs/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from zarr.codecs.arrow import ArrowRecordBatchCodec
34
from zarr.codecs.blosc import BloscCname, BloscCodec, BloscShuffle
45
from zarr.codecs.bytes import BytesCodec, Endian
56
from zarr.codecs.crc32c_ import Crc32cCodec
@@ -34,6 +35,7 @@
3435
from zarr.registry import register_codec
3536

3637
__all__ = [
38+
"ArrowRecordBatchCodec",
3739
"BloscCname",
3840
"BloscCodec",
3941
"BloscShuffle",
@@ -49,6 +51,7 @@
4951
"ZstdCodec",
5052
]
5153

54+
register_codec("arrow", ArrowRecordBatchCodec)
5255
register_codec("blosc", BloscCodec)
5356
register_codec("bytes", BytesCodec)
5457

src/zarr/codecs/arrow.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from typing import TYPE_CHECKING
5+
6+
import pyarrow as pa
7+
8+
from zarr.abc.codec import ArrayBytesCodec
9+
from zarr.core.array_spec import ArraySpec
10+
from zarr.core.buffer import Buffer, NDBuffer
11+
from zarr.core.common import JSON, parse_named_configuration
12+
13+
if TYPE_CHECKING:
14+
from typing_extensions import Self
15+
16+
CHUNK_FIELD_NAME = "zarr_chunk"
17+
18+
19+
@dataclass(frozen=True)
class ArrowRecordBatchCodec(ArrayBytesCodec):
    """Array-bytes codec that serializes chunks as Arrow IPC record batches.

    Each chunk is flattened to a 1-D Arrow array, wrapped in a single record
    batch under the ``CHUNK_FIELD_NAME`` column, and written using the Arrow
    IPC stream format. Decoding reverses the process and restores the chunk
    shape from the array spec.
    """

    def __init__(self) -> None:
        # Stateless codec: no configuration fields.
        pass

    @classmethod
    def from_dict(cls, data: dict[str, JSON]) -> Self:
        """Build the codec from its JSON metadata (``{"name": "arrow"}``).

        The configuration object is optional and, when present, empty.
        """
        _, configuration_parsed = parse_named_configuration(
            data, "arrow", require_configuration=False
        )
        configuration_parsed = configuration_parsed or {}
        return cls(**configuration_parsed)

    def to_dict(self) -> dict[str, JSON]:
        """Return the JSON metadata describing this codec."""
        return {"name": "arrow"}

    def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
        # Nothing in this codec depends on the array spec.
        return self

    async def _decode_single(
        self,
        chunk_bytes: Buffer,
        chunk_spec: ArraySpec,
    ) -> NDBuffer:
        """Decode one chunk from an Arrow IPC stream back into an NDBuffer."""
        assert isinstance(chunk_bytes, Buffer)

        # TODO: make this compatible with buffer prototype
        arrow_buffer = pa.py_buffer(chunk_bytes.to_bytes())
        with pa.ipc.open_stream(arrow_buffer) as reader:
            batches = list(reader)
        # _encode_single always writes exactly one record batch.
        assert len(batches) == 1
        arrow_array = batches[0][CHUNK_FIELD_NAME]
        chunk_array = chunk_spec.prototype.nd_buffer.from_ndarray_like(
            arrow_array.to_numpy(zero_copy_only=False)
        )

        # ensure correct chunk shape (the record batch stores a flattened array)
        if chunk_array.shape != chunk_spec.shape:
            chunk_array = chunk_array.reshape(
                chunk_spec.shape,
            )
        return chunk_array

    async def _encode_single(
        self,
        chunk_array: NDBuffer,
        chunk_spec: ArraySpec,
    ) -> Buffer | None:
        """Encode one chunk as a single-batch Arrow IPC stream."""
        assert isinstance(chunk_array, NDBuffer)
        # Arrow arrays are 1-D, so flatten; _decode_single restores the shape.
        arrow_array = pa.array(chunk_array.as_ndarray_like().ravel())
        rb = pa.record_batch([arrow_array], names=[CHUNK_FIELD_NAME])
        # TODO: allocate buffer differently
        sink = pa.BufferOutputStream()
        with pa.ipc.new_stream(sink, rb.schema) as writer:
            writer.write_batch(rb)
        return chunk_spec.prototype.buffer.from_bytes(memoryview(sink.getvalue()))

    def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int:
        """The IPC-encoded size is not predictable from the input byte length."""
        raise ValueError("Don't know how to compute encoded size!")

tests/test_codecs/test_arrow.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import numpy as np
2+
import pytest
3+
4+
import zarr
5+
from zarr.abc.store import Store
6+
from zarr.codecs import ArrowRecordBatchCodec
7+
from zarr.storage import StorePath
8+
9+
10+
@pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"])
@pytest.mark.parametrize(
    "dtype",
    [
        "uint8",
        "uint16",
        "uint32",
        "uint64",
        "int8",
        "int16",
        "int32",
        "int64",
        "float32",
        "float64",
    ],
)
def test_arrow_standard_dtypes(store: Store, dtype) -> None:
    """Round-trip a single 16x16 chunk of each standard numeric dtype."""
    expected = np.arange(0, 256, dtype=dtype).reshape((16, 16))

    arr = zarr.create_array(
        StorePath(store, path="arrow"),
        shape=expected.shape,
        chunks=(16, 16),
        dtype=expected.dtype,
        fill_value=0,
        serializer=ArrowRecordBatchCodec(),
    )

    arr[:, :] = expected
    assert np.array_equal(expected, arr[:, :])
40+
41+
42+
@pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"])
def test_arrow_vlen_string(store: Store) -> None:
    """Round-trip a small variable-length string array through the codec."""
    words = ["hello", "world", "this", "is", "a", "test"]
    expected = np.array(words).reshape((2, 3))

    arr = zarr.create_array(
        StorePath(store, path="arrow"),
        shape=expected.shape,
        chunks=expected.shape,
        dtype=expected.dtype,
        fill_value=0,
        serializer=ArrowRecordBatchCodec(),
    )

    arr[:, :] = expected
    assert np.array_equal(expected, arr[:, :])

0 commit comments

Comments
 (0)