diff --git a/cubed/core/ops.py b/cubed/core/ops.py
index ee92a09e8..bc926db97 100644
--- a/cubed/core/ops.py
+++ b/cubed/core/ops.py
@@ -621,6 +621,34 @@ def rechunk(x, chunks, target_store=None):
     return Array(name, pipeline.target_array, spec, plan)
 
 
+def merge_chunks(x, chunks):
+    target_chunks = normalize_chunks(chunks, x.shape, dtype=x.dtype)
+
+    target_chunksize = to_chunksize(target_chunks)
+    if len(target_chunksize) != x.ndim:
+        raise ValueError(
+            f"Chunks {target_chunksize} must have same number of dimensions as array ({x.ndim})"
+        )
+    if not all(c1 % c0 == 0 for c0, c1 in zip(x.chunksize, target_chunksize)):
+        raise ValueError(
+            f"Chunks {target_chunksize} must be a multiple of array's chunks {x.chunksize}"
+        )
+
+    return map_direct(
+        _copy_chunk,
+        x,
+        shape=x.shape,
+        dtype=x.dtype,
+        chunks=target_chunks,
+        extra_projected_mem=0,
+        target_chunks=target_chunks,
+    )
+
+
+def _copy_chunk(e, x, target_chunks=None, block_id=None):
+    return x.zarray[get_item(target_chunks, block_id)]
+
+
 def reduction(
     x: "Array",
     func,
@@ -666,9 +694,9 @@ def reduction(
         adjust_chunks=adjust_chunks,
     )
 
-    # rechunk/reduce along axis in multiple rounds until there's a single block in each reduction axis
+    # merge/reduce along axis in multiple rounds until there's a single block in each reduction axis
     while any(n > 1 for i, n in enumerate(result.numblocks) if i in axis):
-        # rechunk along axis
+        # merge along axis
         target_chunks = list(result.chunksize)
         chunk_mem = chunk_memory(intermediate_dtype, result.chunksize)
         for i, s in enumerate(result.shape):
@@ -680,7 +708,7 @@ def reduction(
                 target_chunks[i] = min(s, x.chunksize[i])
             else:
                 # single axis: see how many result chunks fit in max_mem
-                # factor of 4 is memory for {compressed, uncompressed} x {input, output} (see rechunk.py)
+                # factor of 4 is memory for {compressed, uncompressed} x {input, output}
                 target_chunk_size = (max_mem - chunk_mem) // (chunk_mem * 4)
                 if target_chunk_size <= 1:
                     raise ValueError(
@@ -688,7 +716,7 @@ def reduction(
                     )
                 target_chunks[i] = min(s, target_chunk_size)
         _target_chunks = tuple(target_chunks)
-        result = rechunk(result, _target_chunks)
+        result = merge_chunks(result, _target_chunks)
 
     # reduce chunks (if any axis chunksize is > 1)
     if any(s > 1 for i, s in enumerate(result.chunksize) if i in axis):
diff --git a/cubed/tests/test_core.py b/cubed/tests/test_core.py
index 99ea81389..4f00d0a43 100644
--- a/cubed/tests/test_core.py
+++ b/cubed/tests/test_core.py
@@ -10,6 +10,7 @@
 import cubed
 import cubed.array_api as xp
 import cubed.random
+from cubed.core.ops import merge_chunks
 from cubed.extensions.history import HistoryCallback
 from cubed.extensions.timeline import TimelineVisualizationCallback
 from cubed.extensions.tqdm import TqdmProgressBar
@@ -298,12 +299,12 @@ def test_reduction_multiple_rounds(tmp_path, executor):
     spec = cubed.Spec(tmp_path, allowed_mem=1000)
     a = xp.ones((100, 10), dtype=np.uint8, chunks=(1, 10), spec=spec)
     b = xp.sum(a, axis=0, dtype=np.uint8)
-    # check that there is > 1 rechunk step
-    rechunks = [
-        n for (n, d) in b.plan.dag.nodes(data=True) if d["op_name"] == "rechunk"
+    # check that there is > 1 blockwise step (after optimization)
+    blockwises = [
+        n for (n, d) in b.plan.dag.nodes(data=True) if d["op_name"] == "blockwise"
     ]
-    assert len(rechunks) > 1
-    assert b.plan.max_projected_mem() == 1000
+    assert len(blockwises) > 1
+    assert b.plan.max_projected_mem() <= 1000
     assert_array_equal(b.compute(executor=executor), np.ones((100, 10)).sum(axis=0))
 
 
@@ -314,6 +315,23 @@ def test_reduction_not_enough_memory(tmp_path):
         xp.sum(a, axis=0, dtype=np.uint8)
 
 
+@pytest.mark.parametrize("target_chunks", [(2, 3), (4, 3), (2, 6), (4, 6)])
+def test_merge_chunks(spec, target_chunks):
+    a = xp.ones((10, 10), dtype=np.uint8, chunks=(2, 3), spec=spec)
+    b = merge_chunks(a, target_chunks)
+    assert b.chunksize == target_chunks
+    assert_array_equal(b.compute(), np.ones((10, 10)))
+
+
+@pytest.mark.parametrize(
+    "target_chunks", [(2,), (2, 3, 1), (3, 2), (1, 3), (5, 5), (12, 12)]
+)
+def test_merge_chunks_fails(spec, target_chunks):
+    a = xp.ones((10, 10), dtype=np.uint8, chunks=(2, 3), spec=spec)
+    with pytest.raises(ValueError):
+        merge_chunks(a, target_chunks)
+
+
 def test_compute_multiple():
     a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2))
     b = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2))
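
For context on how the new op behaves, here is a small driver in the style of the added tests. It is a sketch, not part of the patch: the work dir and allowed_mem passed to cubed.Spec are placeholder values; everything else uses only names that appear in the diff.

    import numpy as np

    import cubed
    import cubed.array_api as xp
    from cubed.core.ops import merge_chunks

    # placeholder work dir and memory budget (illustrative only)
    spec = cubed.Spec("tmp", allowed_mem=100_000)
    a = xp.ones((10, 10), dtype=np.uint8, chunks=(2, 3), spec=spec)

    # Each target chunk size must be a whole multiple of the current chunk
    # size along that dimension, so (4, 6) combines 2x2 neighbourhoods of
    # source chunks without moving any chunk boundaries.
    b = merge_chunks(a, (4, 6))
    assert b.chunksize == (4, 6)
    np.testing.assert_array_equal(b.compute(), np.ones((10, 10)))

    # Targets such as (3, 2) (not a multiple) or (2,) (wrong rank) raise
    # ValueError: merge_chunks, unlike a full rechunk, can only combine
    # whole existing chunks.

This is also why test_reduction_multiple_rounds expects more than one blockwise step: with a (1, 10) uint8 chunk, chunk_mem is 10 bytes, so with max_mem around the 1000-byte allowance the reduction merges at most (1000 - 10) // (10 * 4) = 24 chunks along axis 0 per round, and collapsing 100 blocks therefore takes more than one merge/reduce round.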