diff --git a/doc/whats-new.rst b/doc/whats-new.rst index d4a46c1e020..e1012283c94 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -36,6 +36,10 @@ Breaking changes New Features ~~~~~~~~~~~~ + +- ``chunks='auto'`` is now supported in the ``chunks`` argument of + :py:meth:`Dataset.chunk`. (:issue:`4055`) + By `Andrew Williams `_ - Added :py:func:`xarray.cov` and :py:func:`xarray.corr` (:issue:`3784`, :pull:`3550`, :pull:`4089`). By `Andrew Williams `_ and `Robin Beer `_. - Added :py:meth:`DataArray.polyfit` and :py:func:`xarray.polyval` for fitting polynomials. (:issue:`3349`) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index d93f4044a6d..43f6ad9c90e 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1707,7 +1707,10 @@ def chunks(self) -> Mapping[Hashable, Tuple[int, ...]]: def chunk( self, chunks: Union[ - None, Number, Mapping[Hashable, Union[None, Number, Tuple[Number, ...]]] + None, + Number, + str, + Mapping[Hashable, Union[None, Number, str, Tuple[Number, ...]]], ] = None, name_prefix: str = "xarray-", token: str = None, @@ -1725,7 +1728,7 @@ def chunk( Parameters ---------- - chunks : int or mapping, optional + chunks : int, 'auto' or mapping, optional Chunk sizes along each dimension, e.g., ``5`` or ``{'x': 5, 'y': 5}``. name_prefix : str, optional @@ -1742,7 +1745,7 @@ def chunk( """ from dask.base import tokenize - if isinstance(chunks, Number): + if isinstance(chunks, (Number, str)): chunks = dict.fromkeys(self.dims, chunks) if chunks is not None: diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 75beb3757ca..6f714fe1825 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1035,6 +1035,14 @@ def test_unify_chunks_shallow_copy(obj, transform): assert_identical(obj, unified) and obj is not obj.unify_chunks() +@pytest.mark.parametrize("obj", [make_da()]) +def test_auto_chunk_da(obj): + actual = obj.chunk("auto").data + expected = obj.data.rechunk("auto") + np.testing.assert_array_equal(actual, expected) + assert actual.chunks == expected.chunks + + def test_map_blocks_error(map_da, map_ds): def bad_func(darray): return (darray * darray.x + 5 * darray.y)[:1, :1]