Skip to content

Commit 784cb28

Browse files
committed
Merge remote-tracking branch 'upstream/v3' into user/tom/fix/v2-compat
2 parents f937468 + 60b4f57 commit 784cb28

33 files changed

+1082
-167
lines changed

.github/workflows/gpu_test.yml

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2+
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3+
4+
name: GPU Test V3
5+
6+
on:
7+
push:
8+
branches: [ v3 ]
9+
pull_request:
10+
branches: [ v3 ]
11+
workflow_dispatch:
12+
13+
env:
14+
LD_LIBRARY_PATH: /usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/lib64
15+
16+
concurrency:
17+
group: ${{ github.workflow }}-${{ github.ref }}
18+
cancel-in-progress: true
19+
20+
jobs:
21+
test:
22+
name: py=${{ matrix.python-version }}, np=${{ matrix.numpy-version }}, deps=${{ matrix.dependency-set }}
23+
24+
runs-on: gpu-runner
25+
strategy:
26+
matrix:
27+
python-version: ['3.11']
28+
numpy-version: ['2.0']
29+
dependency-set: ["minimal"]
30+
31+
steps:
32+
- uses: actions/checkout@v4
33+
# - name: cuda-toolkit
34+
# uses: Jimver/[email protected]
35+
# id: cuda-toolkit
36+
# with:
37+
# cuda: '12.4.1'
38+
- name: Set up CUDA
39+
run: |
40+
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-keyring_1.1-1_all.deb
41+
sudo dpkg -i cuda-keyring_1.1-1_all.deb
42+
sudo apt-get update
43+
sudo apt-get -y install cuda-toolkit-12-6
44+
echo "/usr/local/cuda/bin" >> $GITHUB_PATH
45+
- name: GPU check
46+
run: |
47+
nvidia-smi
48+
echo $PATH
49+
echo $LD_LIBRARY_PATH
50+
nvcc -V
51+
- name: Set up Python
52+
uses: actions/setup-python@v5
53+
with:
54+
python-version: ${{ matrix.python-version }}
55+
cache: 'pip'
56+
- name: Install Hatch and CuPy
57+
run: |
58+
python -m pip install --upgrade pip
59+
pip install hatch
60+
- name: Set Up Hatch Env
61+
run: |
62+
hatch env create gputest.py${{ matrix.python-version }}-${{ matrix.numpy-version }}-${{ matrix.dependency-set }}
63+
hatch env run -e gputest.py${{ matrix.python-version }}-${{ matrix.numpy-version }}-${{ matrix.dependency-set }} list-env
64+
- name: Run Tests
65+
run: |
66+
hatch env run --env gputest.py${{ matrix.python-version }}-${{ matrix.numpy-version }}-${{ matrix.dependency-set }} run-coverage

.github/workflows/releases.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ jobs:
5555
with:
5656
name: releases
5757
path: dist
58-
- uses: pypa/gh-action-pypi-publish@v1.9.0
58+
- uses: pypa/gh-action-pypi-publish@v1.10.0
5959
with:
6060
user: __token__
6161
password: ${{ secrets.pypi_password }}

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ default_language_version:
77
python: python3
88
repos:
99
- repo: https://github.com/astral-sh/ruff-pre-commit
10-
rev: 'v0.5.7'
10+
rev: v0.6.3
1111
hooks:
1212
- id: ruff
1313
args: ["--fix", "--show-fixes"]

pyproject.toml

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ jupyter = [
7575
'ipytree>=0.2.2',
7676
'ipywidgets>=8.0.0',
7777
]
78+
gpu = [
79+
"cupy-cuda12x",
80+
]
7881
docs = [
7982
'sphinx',
8083
'sphinx-autobuild>=2021.3.14',
@@ -121,7 +124,7 @@ build.hooks.vcs.version-file = "src/zarr/_version.py"
121124
[tool.hatch.envs.test]
122125
dependencies = [
123126
"numpy~={matrix:numpy}",
124-
"universal_pathlib"
127+
"universal_pathlib",
125128
]
126129
features = ["test", "extra"]
127130

@@ -135,8 +138,34 @@ python = ["3.10", "3.11", "3.12"]
135138
numpy = ["1.24", "1.26", "2.0"]
136139
features = ["optional"]
137140

141+
[[tool.hatch.envs.test.matrix]]
142+
python = ["3.10", "3.11", "3.12"]
143+
numpy = ["1.24", "1.26", "2.0"]
144+
features = ["gpu"]
145+
138146
[tool.hatch.envs.test.scripts]
139147
run-coverage = "pytest --cov-config=pyproject.toml --cov=pkg --cov=tests"
148+
run-coverage-gpu = "pip install cupy-cuda12x && pytest -m gpu --cov-config=pyproject.toml --cov=pkg --cov=tests"
149+
run = "run-coverage --no-cov"
150+
run-verbose = "run-coverage --verbose"
151+
run-mypy = "mypy src"
152+
run-hypothesis = "pytest --hypothesis-profile ci tests/v3/test_properties.py tests/v3/test_store/test_stateful*"
153+
list-env = "pip list"
154+
155+
[tool.hatch.envs.gputest]
156+
dependencies = [
157+
"numpy~={matrix:numpy}",
158+
"universal_pathlib",
159+
]
160+
features = ["test", "extra", "gpu"]
161+
162+
[[tool.hatch.envs.gputest.matrix]]
163+
python = ["3.10", "3.11", "3.12"]
164+
numpy = ["1.24", "1.26", "2.0"]
165+
version = ["minimal"]
166+
167+
[tool.hatch.envs.gputest.scripts]
168+
run-coverage = "pytest -m gpu --cov-config=pyproject.toml --cov=pkg --cov=tests"
140169
run = "run-coverage --no-cov"
141170
run-verbose = "run-coverage --verbose"
142171
run-mypy = "mypy src"
@@ -169,6 +198,7 @@ extend-exclude = [
169198
"buck-out",
170199
"build",
171200
"dist",
201+
"notebooks", # temporary, until we achieve compatibility with ruff ≥ 0.6
172202
"venv",
173203
"docs",
174204
"src/zarr/v2/",
@@ -224,4 +254,8 @@ filterwarnings = [
224254
"error:::zarr.*",
225255
"ignore:PY_SSIZE_T_CLEAN will be required.*:DeprecationWarning",
226256
"ignore:The loop argument is deprecated since Python 3.8.*:DeprecationWarning",
257+
"ignore:Creating a zarr.buffer.gpu.*:UserWarning",
258+
]
259+
markers = [
260+
"gpu: mark a test as requiring CuPy and GPU"
227261
]

src/zarr/codecs/blosc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010

1111
from zarr.abc.codec import BytesBytesCodec
1212
from zarr.core.array_spec import ArraySpec
13-
from zarr.core.buffer import Buffer, as_numpy_array_wrapper
13+
from zarr.core.buffer import Buffer
14+
from zarr.core.buffer.cpu import as_numpy_array_wrapper
1415
from zarr.core.common import JSON, parse_enum, parse_named_configuration, to_thread
1516
from zarr.registry import register_codec
1617

src/zarr/codecs/crc32c_.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ async def _decode_single(
3838
crc32_bytes = data[-4:]
3939
inner_bytes = data[:-4]
4040

41+
# Need to do a manual cast until https://github.com/numpy/numpy/issues/26783 is resolved
4142
computed_checksum = np.uint32(crc32c(cast(typing_extensions.Buffer, inner_bytes))).tobytes()
4243
stored_checksum = bytes(crc32_bytes)
4344
if computed_checksum != stored_checksum:

src/zarr/codecs/gzip.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77

88
from zarr.abc.codec import BytesBytesCodec
99
from zarr.core.array_spec import ArraySpec
10-
from zarr.core.buffer import Buffer, as_numpy_array_wrapper
10+
from zarr.core.buffer import Buffer
11+
from zarr.core.buffer.cpu import as_numpy_array_wrapper
1112
from zarr.core.common import JSON, parse_named_configuration, to_thread
1213
from zarr.registry import register_codec
1314

src/zarr/codecs/sharding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ class _ShardIndex(NamedTuple):
101101

102102
@property
103103
def chunks_per_shard(self) -> ChunkCoords:
104-
result = tuple(self.offsets_and_lengths[:-1])
104+
result = tuple(self.offsets_and_lengths.shape[0:-1])
105105
# The cast is required until https://github.com/numpy/numpy/pull/27211 is merged
106106
return cast(ChunkCoords, result)
107107

src/zarr/codecs/zstd.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
from zarr.abc.codec import BytesBytesCodec
1111
from zarr.core.array_spec import ArraySpec
12-
from zarr.core.buffer import Buffer, as_numpy_array_wrapper
12+
from zarr.core.buffer import Buffer
13+
from zarr.core.buffer.cpu import as_numpy_array_wrapper
1314
from zarr.core.common import JSON, parse_named_configuration, to_thread
1415
from zarr.registry import register_codec
1516

src/zarr/core/array.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -513,15 +513,24 @@ async def _set_selection(
513513

514514
# check value shape
515515
if np.isscalar(value):
516-
value = np.asanyarray(value, dtype=self.metadata.dtype)
516+
array_like = prototype.buffer.create_zero_length().as_array_like()
517+
if isinstance(array_like, np._typing._SupportsArrayFunc):
518+
# TODO: need to handle array types that don't support __array_function__
519+
# like PyTorch and JAX
520+
array_like_ = cast(np._typing._SupportsArrayFunc, array_like)
521+
value = np.asanyarray(value, dtype=self.metadata.dtype, like=array_like_)
517522
else:
518523
if not hasattr(value, "shape"):
519524
value = np.asarray(value, self.metadata.dtype)
520525
# assert (
521526
# value.shape == indexer.shape
522527
# ), f"shape of value doesn't match indexer shape. Expected {indexer.shape}, got {value.shape}"
523528
if not hasattr(value, "dtype") or value.dtype.name != self.metadata.dtype.name:
524-
value = np.array(value, dtype=self.metadata.dtype, order="A")
529+
if hasattr(value, "astype"):
530+
# Handle things that are already NDArrayLike more efficiently
531+
value = value.astype(dtype=self.metadata.dtype, order="A")
532+
else:
533+
value = np.array(value, dtype=self.metadata.dtype, order="A")
525534
value = cast(NDArrayLike, value)
526535
# We accept any ndarray like object from the user and convert it
527536
# to a NDBuffer (or subclass). From this point onwards, we only pass

src/zarr/core/buffer/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from zarr.core.buffer.core import (
2+
ArrayLike,
3+
Buffer,
4+
BufferPrototype,
5+
NDArrayLike,
6+
NDBuffer,
7+
default_buffer_prototype,
8+
)
9+
from zarr.core.buffer.cpu import numpy_buffer_prototype
10+
11+
__all__ = [
12+
"ArrayLike",
13+
"Buffer",
14+
"NDArrayLike",
15+
"NDBuffer",
16+
"BufferPrototype",
17+
"default_buffer_prototype",
18+
"numpy_buffer_prototype",
19+
]

0 commit comments

Comments
 (0)