Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions .github/workflows/tensorstore-tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Nightly-style CI job that runs the cubed test suite against the
# tensorstore storage backend (CUBED_STORAGE_NAME=tensorstore).
name: TensorStore tests

on:
  schedule:
    # Every weekday at 03:58 UTC, see https://crontab.guru/
    - cron: "58 3 * * 1-5"
  # Allow manual runs from the Actions tab.
  workflow_dispatch:

# Cancel an in-flight run when a newer one starts for the same ref.
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  test:
    runs-on: ${{ matrix.os }}
    strategy:
      # Run every matrix entry to completion even if one fails.
      fail-fast: false
      matrix:
        os: ["ubuntu-latest"]
        python-version: ["3.11"]

    steps:
      - name: Checkout source
        uses: actions/checkout@v3
        with:
          # Full history (fetch-depth: 0) rather than a shallow clone.
          fetch-depth: 0

      - name: Set up Python
        uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python-version }}
          architecture: x64

      - name: Setup Graphviz
        uses: ts-graphviz/setup-graphviz@v2

      - name: Install
        run: |
          python -m pip install --upgrade pip
          python -m pip install -e '.[test]' 'tensorstore'

      - name: Run tests
        run: |
          # exclude tests that rely on the nchunks_initialized array attribute
          pytest -k "not test_resume"
        env:
          # Selects the tensorstore storage backend in cubed.
          CUBED_STORAGE_NAME: tensorstore
2 changes: 2 additions & 0 deletions cubed/backend_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,6 @@ def backend_array_to_numpy_array(arr):


def numpy_array_to_backend_array(arr, *, dtype=None):
    """Convert ``arr`` to an array in the backend array namespace.

    ``arr`` may also be a dict mapping field names to arrays (as produced
    for structured dtypes), in which case each value is converted and a
    dict with the same keys is returned.
    """
    if not isinstance(arr, dict):
        return namespace.asarray(arr, dtype=dtype)
    converted = {}
    for field, values in arr.items():
        converted[field] = namespace.asarray(values, dtype=dtype)
    return converted
5 changes: 5 additions & 0 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,11 @@ def merge_chunks(x, chunks):


def _copy_chunk(e, x, target_chunks=None, block_id=None):
    """Read the chunk of ``x`` selected by ``block_id`` and convert it to a
    backend array.

    When ``x.zarray`` is a dict (one underlying array per structured-dtype
    field), a dict of converted chunks keyed by field name is returned.
    """
    selection = get_item(target_chunks, block_id)
    if isinstance(x.zarray, dict):
        return {
            field: numpy_array_to_backend_array(zarray[selection])
            for field, zarray in x.zarray.items()
        }
    return numpy_array_to_backend_array(x.zarray[selection])
Expand Down
4 changes: 4 additions & 0 deletions cubed/storage/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ def open_backend_array(
from cubed.storage.backends.zarr_python import open_zarr_array

open_func = open_zarr_array
elif storage_name == "tensorstore":
from cubed.storage.backends.tensorstore import open_tensorstore_array

open_func = open_tensorstore_array
else:
raise ValueError(f"Unrecognized storage name: {storage_name}")

Expand Down
155 changes: 155 additions & 0 deletions cubed/storage/backends/tensorstore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import dataclasses
import math
from typing import Any, Dict, Optional

import numpy as np
import tensorstore

from cubed.types import T_DType, T_RegularChunks, T_Shape, T_Store
from cubed.utils import join_path


@dataclasses.dataclass(frozen=True)
class TensorStoreArray:
    """Minimal zarr-array-like adapter around a ``tensorstore.TensorStore``.

    Exposes the attribute and indexing surface cubed's storage layer expects
    from a backend array (shape/dtype/chunks/ndim/size, plain and orthogonal
    indexing), resolving reads eagerly on ``__getitem__``.
    """

    # The underlying tensorstore handle.
    array: tensorstore.TensorStore

    @property
    def shape(self) -> tuple[int, ...]:
        return self.array.shape

    @property
    def dtype(self) -> np.dtype:
        # tensorstore has its own dtype objects; expose the NumPy equivalent.
        return self.array.dtype.numpy_dtype

    @property
    def chunks(self) -> tuple[int, ...]:
        # Fall back to an empty tuple when tensorstore reports no read-chunk
        # shape, so callers always receive a tuple.
        return self.array.chunk_layout.read_chunk.shape or ()

    @property
    def ndim(self) -> int:
        return len(self.shape)

    @property
    def size(self) -> int:
        # Total number of elements (product over all dimensions).
        return math.prod(self.shape)

    @property
    def oindex(self):
        # Orthogonal (outer) indexing, mirroring zarr's ``oindex``.
        return self.array.oindex

    def __getitem__(self, key):
        # read eagerly: resolve the asynchronous tensorstore read to a
        # concrete NumPy array before returning.
        return self.array.__getitem__(key).read().result()

    def __setitem__(self, key, value):
        self.array.__setitem__(key, value)


class TensorStoreGroup(dict):
    """A group of arrays, one per field of a structured dtype.

    Behaves like a dict mapping field name -> array, while also exposing
    zarr-group-like ``shape``/``dtype``/``chunks`` attributes and indexing:
    a string key selects a field's array, any other key is treated as an
    index expression applied to every field.
    """

    def __init__(
        self,
        shape: "Optional[T_Shape]" = None,
        dtype: "Optional[T_DType]" = None,
        chunks: "Optional[T_RegularChunks]" = None,
    ):
        dict.__init__(self)
        # Metadata describing the (structured) group as a whole.
        self.shape = shape
        self.dtype = dtype
        self.chunks = chunks

    def __getitem__(self, key):
        if isinstance(key, str):
            # Field name: return that field's array.
            return super().__getitem__(key)
        # Index expression: apply it to every field's array.
        return {field: zarray[key] for field, zarray in self.items()}

    def set_basic_selection(self, selection, value, fields=None):
        """Write ``value`` into the region ``selection`` of field ``fields``.

        Raises:
            ValueError: if ``fields`` is None — a field name is required;
                previously this surfaced as a cryptic ``KeyError(None)``.
        """
        if fields is None:
            raise ValueError(
                "`fields` must specify a field name when writing to a group"
            )
        self[fields][selection] = value


def encode_dtype(d):
    """Encode a NumPy dtype in zarr metadata form.

    Plain dtypes encode to their array-protocol string (e.g. ``"<i4"``);
    structured dtypes encode to their full field-description list.
    """
    return d.str if d.fields is None else d.descr


def get_metadata(dtype, chunks):
    """Build the zarr ``metadata`` section of a tensorstore spec.

    Keys are included only for the values actually provided, so an empty
    dict means there is nothing to constrain.
    """
    metadata = {}
    if dtype is not None:
        metadata["dtype"] = encode_dtype(np.dtype(dtype))
    if chunks is not None:
        # Normalize a bare int into a 1-tuple chunk shape.
        metadata["chunks"] = (chunks,) if isinstance(chunks, int) else chunks
    return metadata


def open_tensorstore_array(
    store: T_Store,
    mode: str,
    *,
    shape: Optional[T_Shape] = None,
    dtype: Optional[T_DType] = None,
    chunks: Optional[T_RegularChunks] = None,
    path: Optional[str] = None,
    **kwargs,
):
    """Open a zarr-format array via tensorstore.

    Returns a ``TensorStoreArray`` for plain dtypes, or a ``TensorStoreGroup``
    with one array per field (stored under ``path/<field>``) for structured
    dtypes.

    Args:
        store: Root location — either a URL understood by tensorstore
            (contains ``"://"``) or a local filesystem path.
        mode: zarr-style open mode: ``"r"``, ``"r+"``, ``"a"``, ``"w"``,
            or ``"w-"``.
        shape: Overall array shape; may be None when opening existing data.
        dtype: NumPy dtype (possibly structured); may be None.
        chunks: Regular chunk shape, or a bare int for 1-D; may be None.
        path: Optional sub-path within the store.
        **kwargs: accepted for signature compatibility but ignored here.
            NOTE(review): confirm nothing upstream relies on extra kwargs.

    Raises:
        ValueError: if ``mode`` is not one of the supported modes.
    """
    store = str(store)  # TODO: check if Path or str

    spec: Dict[str, Any]
    if "://" in store:
        # URL-style kvstore (e.g. "memory://"); tensorstore parses the URL.
        # NOTE(review): `path` is not applied in this branch for plain
        # dtypes — confirm callers encode any sub-path in the URL itself.
        spec = {"driver": "zarr", "kvstore": store}
    else:
        # Local filesystem store.
        spec = {
            "driver": "zarr",
            "kvstore": {"driver": "file", "path": store},
            "path": path or "",
        }

    # Translate zarr-style modes into tensorstore open/create/delete flags.
    if mode == "r":
        open_kwargs = dict(read=True, open=True)
    elif mode == "r+":
        open_kwargs = dict(read=True, write=True, open=True)
    elif mode == "a":
        open_kwargs = dict(read=True, write=True, open=True, create=True)
    elif mode == "w":
        open_kwargs = dict(read=True, write=True, create=True, delete_existing=True)
    elif mode == "w-":
        open_kwargs = dict(read=True, write=True, create=True)
    else:
        raise ValueError(f"Mode not supported: {mode}")

    if dtype is None or not hasattr(dtype, "fields") or dtype.fields is None:
        # Plain (non-structured) dtype: open a single array. The hasattr
        # guard covers dtypes given as strings or type objects.
        metadata = get_metadata(dtype, chunks)
        if metadata:
            spec["metadata"] = metadata

        return TensorStoreArray(
            tensorstore.open(
                spec,
                shape=shape,
                dtype=dtype,
                **open_kwargs,
            ).result()
        )
    else:
        # Structured dtype: open one array per field, collected in a group.
        ret = TensorStoreGroup(shape=shape, dtype=dtype, chunks=chunks)
        for field in dtype.fields:
            # Each field lives at its own sub-path within the store.
            field_path = field if path is None else join_path(path, field)
            spec["path"] = field_path

            field_dtype, _ = dtype.fields[field]
            metadata = get_metadata(field_dtype, chunks)
            if metadata:
                spec["metadata"] = metadata

            ret[field] = TensorStoreArray(
                tensorstore.open(
                    spec,
                    shape=shape,
                    dtype=field_dtype,
                    **open_kwargs,
                ).result()
            )
        return ret
4 changes: 2 additions & 2 deletions cubed/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@


def test_groupby_reduction_axis0():
    # Constant float array: every group has the value 7, so the per-group
    # mean along axis 0 must be exactly 7 for all 2 groups x 5 columns.
    a = xp.full((4 * 6, 5), 7.0, chunks=(4, 2))
    b = xp.asarray([0, 1, 0, 1] * 6, chunks=(4,))
    c = mean_groupby_reduction(a, b, axis=0, num_groups=2)
    assert_array_equal(c.compute(), np.full((2, 5), 7))


def test_groupby_reduction_axis1():
    # Same as the axis-0 test but grouping along axis 1: a constant float
    # array means every per-group mean is exactly 7 (5 rows x 2 groups).
    a = xp.full((5, 4 * 6), 7.0, chunks=(2, 4))
    b = xp.asarray([0, 1, 0, 1] * 6, chunks=(4,))
    c = mean_groupby_reduction(a, b, axis=1, num_groups=2)
    assert_array_equal(c.compute(), np.full((5, 2), 7))
Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ ignore_missing_imports = True
ignore_missing_imports = True
[mypy-tenacity.*]
ignore_missing_imports = True
[mypy-tensorstore.*]
ignore_missing_imports = True
[mypy-tlz.*]
ignore_missing_imports = True
[mypy-toolz.*]
Expand Down