diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py
index 0f19a1b51be..0290775edd7 100644
--- a/xarray/backends/rasterio_.py
+++ b/xarray/backends/rasterio_.py
@@ -176,6 +176,9 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
         Chunk sizes along each dimension, e.g., ``5``, ``(5, 5)`` or
         ``{'x': 5, 'y': 5}``. If chunks is provided, it used to load the new
         DataArray into a dask array.
+        Chunks can also be set to ``True`` or ``"auto"`` to choose sensible
+        chunk sizes according to ``dask.config.get("array.chunk-size")``.
+        Automatic chunking requires a tiled data store and dask >= 0.18.0.
     cache : bool, optional
         If True, cache data loaded from the underlying datastore in memory as
         NumPy arrays when accessed to avoid reading from the underlying data-
@@ -283,6 +286,11 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
     # this lets you write arrays loaded with rasterio
     data = indexing.CopyOnWriteArray(data)
 
+    # ``1 == True`` in Python, so test ``True`` by identity: ``chunks=1``
+    # must stay an explicit chunk size rather than trigger auto-chunking.
+    if chunks is True or chunks == 'auto':
+        chunks = (1, 'auto', 'auto')
+
     if cache and (chunks is None):
         data = indexing.MemoryCachedArray(data)
 
@@ -301,6 +309,41 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
         name_prefix = 'open_rasterio-%s' % token
         if lock is None:
             lock = RASTERIO_LOCK
+
+        # Only resolve chunk sizes when 'auto' was requested. Explicit chunk
+        # specifications skip this step, so they keep working on untiled
+        # data stores and with older versions of dask.
+        if isinstance(chunks, (tuple, list)) and 'auto' in chunks:
+            if not attrs.get('is_tiled', False):
+                msg = ("Data store is not tiled. "
+                       "Automatic chunking is not sensible")
+                raise ValueError(msg)
+
+            import dask.array
+            # Compare numeric version components: as plain strings,
+            # '0.9.0' would sort after '0.18.0'.
+            dask_version = tuple(
+                int(part) for part in dask.__version__.split('.')[:2])
+            if dask_version < (0, 18):
+                msg = ("Automatic chunking requires dask.__version__ >= "
+                       "0.18.0 . \n"
+                       "You currently have version %s" % dask.__version__)
+                raise NotImplementedError(msg)
+
+            img = riods._ds
+            # Seed dask with the first band's block shape; indexing the
+            # list keeps the choice deterministic.
+            block_shape = (1,) + tuple(img.block_shapes[0])
+            previous_chunks = tuple((c,) for c in block_shape)
+            shape = (img.count, img.height, img.width)
+            dtype = img.dtypes[0]
+            chunks = dask.array.core.normalize_chunks(
+                chunks,
+                shape=shape,
+                previous_chunks=previous_chunks,
+                dtype=dtype
+            )
+
         result = result.chunk(chunks, name_prefix=name_prefix, token=token,
                               lock=lock)
 
diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py
index e83b80a6dd8..6b9407d9932 100644
--- a/xarray/tests/test_backends.py
+++ b/xarray/tests/test_backends.py
@@ -1340,7 +1340,6 @@ def test_write_uneven_dask_chunks(self):
                 print(k)
                 assert v.chunks == actual[k].chunks
 
-
     def test_chunk_encoding(self):
         # These datasets have no dask chunks. All chunking specified in
         # encoding
@@ -3009,6 +3008,24 @@ def test_chunks(self):
                 ex = expected.sel(band=1).mean(dim='x')
                 assert_allclose(ac, ex)
 
+    @requires_dask
+    def test_chunks_auto(self):
+        # Automatic chunk sizes are checked against a fixed
+        # ``array.chunk-size`` so the expected values are stable.
+        import dask
+        with dask.config.set({'array.chunk-size': '64kiB'}):
+            # TODO: enhance create_tmp_geotiff to support tiled images
+            with create_tmp_geotiff(1024, 2048, 3) as (tmp_file, expected):
+                with xr.open_rasterio(tmp_file, chunks=True) as actual:
+                    assert actual.chunks[0] == (1, 1, 1)
+                    assert actual.chunks[1] == (256,) * 4
+                    assert actual.chunks[2] == (256,) * 8
+                with xr.open_rasterio(
+                        tmp_file, chunks=(3, 'auto', 'auto')) as actual:
+                    assert actual.chunks[0] == (3,)
+                    assert actual.chunks[1] == (128,) * 8
+                    assert actual.chunks[2] == (128,) * 16
+
     def test_pickle_rasterio(self):
         # regression test for https://github.com/pydata/xarray/issues/2121
         with create_tmp_geotiff() as (tmp_file, expected):