diff --git a/pyproject.toml b/pyproject.toml index 96eb566..8543942 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ classifiers = [ dynamic = [ "description", "version" ] dependencies = [ "numpy" ] optional-dependencies.accel = [ "numba" ] +optional-dependencies.dask = [ "dask>=2025.3" ] optional-dependencies.doc = [ "furo", "pytest", @@ -29,7 +30,8 @@ optional-dependencies.doc = [ "sphinx-autodoc-typehints", "sphinx-autofixture", ] -optional-dependencies.full = [ "dask", "fast-array-utils[accel,sparse]", "h5py", "zarr" ] +optional-dependencies.full = [ "fast-array-utils[accel,dask,h5py,sparse,zarr]" ] +optional-dependencies.h5py = [ "h5py" ] optional-dependencies.sparse = [ "scipy>=1.8" ] optional-dependencies.test = [ "anndata", @@ -44,6 +46,7 @@ optional-dependencies.test-min = [ "pytest-doctestplus", ] optional-dependencies.testing = [ "packaging" ] +optional-dependencies.zarr = [ "zarr" ] urls.'Documentation' = "https://icb-fast-array-utils.readthedocs-hosted.com/" urls.'Issue Tracker' = "https://github.com/scverse/fast-array-utils/issues" urls.'Source Code' = "https://github.com/scverse/fast-array-utils" diff --git a/src/fast_array_utils/__init__.py b/src/fast_array_utils/__init__.py index ba413cf..f73434b 100644 --- a/src/fast_array_utils/__init__.py +++ b/src/fast_array_utils/__init__.py @@ -15,9 +15,7 @@ from __future__ import annotations -from . import _patches, conv, stats, types +from . import conv, stats, types __all__ = ["conv", "stats", "types"] - -_patches.patch_dask() diff --git a/src/fast_array_utils/_checks.py b/src/fast_array_utils/_checks.py new file mode 100644 index 0000000..21f8a49 --- /dev/null +++ b/src/fast_array_utils/_checks.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from functools import cache, wraps +from importlib.metadata import version +from typing import TYPE_CHECKING + +from packaging.version import Version + +from . import types + + +if TYPE_CHECKING: + from collections.abc import Callable + from typing import Concatenate, ParamSpec, TypeVar + + _DA = TypeVar("_DA", bound=types.DaskArray) + _P = ParamSpec("_P") + _R = TypeVar("_R") + + +__all__ = ["check_dask_sparray_support"] + + +@cache +def _dask_supports_sparray() -> bool: + return Version(version("dask")) >= Version("2025.3") + + +def check_dask_sparray_support( + func: Callable[Concatenate[_DA, _P], _R], +) -> Callable[Concatenate[_DA, _P], _R]: + """Check that Dask isn’t too old when trying to use it with `scipy.sparse.sparray`s.""" + + @wraps(func) + def decorated(arr: _DA, *args: _P.args, **kwargs: _P.kwargs) -> _R: + if ( + isinstance(arr, types.DaskArray) + and isinstance(arr._meta, types.sparray) # noqa: SLF001 + and not _dask_supports_sparray() + ): + msg = "dask < 2025.3 does not support `scipy.sparse.sparray`s" + raise RuntimeError(msg) + return func(arr, *args, **kwargs) + + return decorated diff --git a/src/fast_array_utils/_patches.py b/src/fast_array_utils/_patches.py deleted file mode 100644 index d5dd98d..0000000 --- a/src/fast_array_utils/_patches.py +++ /dev/null @@ -1,26 +0,0 @@ -# SPDX-License-Identifier: MPL-2.0 -from __future__ import annotations - -import numpy as np - - -# TODO(flying-sheep): upstream -# https://github.com/dask/dask/issues/11749 -def patch_dask() -> None: # pragma: no cover - """Patch dask to support sparse arrays. - - See - """ - try: - # Other lookup candidates: tensordot_lookup and take_lookup - from dask.array.dispatch import concatenate_lookup - from scipy.sparse import sparray, spmatrix - except ImportError: - return # No need to patch if dask or scipy is not installed - - # Avoid patch if already patched or upstream support has been added - if concatenate_lookup.dispatch(sparray) is not np.concatenate: - return - - concatenate = concatenate_lookup.dispatch(spmatrix) - concatenate_lookup.register(sparray, concatenate) diff --git a/src/fast_array_utils/conv/__init__.py b/src/fast_array_utils/conv/__init__.py index 5cd1e4e..5084615 100644 --- a/src/fast_array_utils/conv/__init__.py +++ b/src/fast_array_utils/conv/__init__.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, overload +from .._checks import check_dask_sparray_support from ..typing import CpuArray, DiskArray, GpuArray # noqa: TC001 from ._to_dense import to_dense_ @@ -47,6 +48,7 @@ def to_dense( ) -> NDArray[Any]: ... +@check_dask_sparray_support def to_dense( x: CpuArray | GpuArray diff --git a/src/fast_array_utils/stats/__init__.py b/src/fast_array_utils/stats/__init__.py index 8dabc66..8fbd5d0 100644 --- a/src/fast_array_utils/stats/__init__.py +++ b/src/fast_array_utils/stats/__init__.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, overload +from .._checks import check_dask_sparray_support from .._validation import validate_axis from ..typing import CpuArray, DiskArray, GpuArray # noqa: TC001 @@ -38,6 +39,7 @@ def is_constant(x: types.CupyArray, /, *, axis: Literal[0, 1]) -> types.CupyArra def is_constant(x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None) -> types.DaskArray: ... +@check_dask_sparray_support def is_constant( x: NDArray[Any] | types.CSBase | types.CupyArray | types.DaskArray, /, @@ -103,6 +105,7 @@ def mean( ) -> types.DaskArray: ... +@check_dask_sparray_support def mean( x: CpuArray | GpuArray | DiskArray | types.DaskArray, /, @@ -166,6 +169,7 @@ def mean_var( ) -> tuple[types.DaskArray, types.DaskArray]: ... +@check_dask_sparray_support def mean_var( x: CpuArray | GpuArray | types.DaskArray, /, @@ -242,6 +246,7 @@ def sum( ) -> types.DaskArray: ... +@check_dask_sparray_support def sum( x: CpuArray | GpuArray | DiskArray | types.DaskArray, /,