Commit 3f29551

dcherian authored and crusaderky committed
map_blocks (#3276)
* map_block attempt 2
* Address reviews: errors, args + kwargs support.
* Works with datasets!
* remove wrong comment.
* Support chunks.
* infer template.
* cleanup
* cleanup2
* api.rst
* simple shape change error check.
* Make test more complicated.
* Fix for when user function doesn't set DataArray.name
* Now _to_temp_dataset works.
* Add whats-new
* chunks kwarg makes no sense right now.
* review feedback: 1. skip index graph nodes. 2. var → name 3. quicker dataarray creation. 4. Add restrictions to docstring. 5. rename chunk construction task. 6. error when non-xarray object is returned. 7. restore non-coord dims. review
* Support nondim coords in make_meta.
* Add Dataset.unify_chunks
* doc updates.
* minor.
* update comment.
* More complicated test dataset. Tests fail :X
* Don't know why compute is needed.
* work with DataArray nondim coords.
* fastpath unify_chunks
* comment.
* much improved tests.
* Change args, kwargs syntax.
* Add dataset, dataarray methods.
* api.rst
* docstrings.
* Fix unify_chunks.
* Move assert_chunks_equal to xarray.testing.
* minor changes.
* Better error handling when inferring returned object
* wip
* wip
* better to_array
* Docstrings + nicer error message.
* remove unify_chunks in map_blocks + better tests.
* typing for unify_chunks
* address more review comments.
* more unify_chunks tests.
* Just use dask.core.utils.meta_from_array
* get tests working. assert_equal needs a lot of fixing.
* more unify_chunks test.
* assert_chunks_equal fixes.
* copy over meta_from_array.
* minor fixes.
* raise chunks error earlier and test for map_blocks raising chunk error
* fix.
* Type annotations
* py35 compat
* make sure unify_chunks does not compute.
* Make tests functional by call compute before assert_equal
* Update whats-new
* Work with attributes.
* Support attrs and name changes.
* more assert_equal
* test changing coord attribute
* fix whats new
* rework tests to use fixtures (kind of)
* more review changes.
* cleanup
* more review feedback.
* fix unify_chunks.
* read dask_array_compat :)
* Dask 1.2.0 compat.
* documentation polish
* make_meta reflow
* cosmetic
* polish
* Fix tests
* isort
* isort
* Add func call to docstrings.
1 parent 291cb80 · commit 3f29551

15 files changed: +910 −26 lines

doc/api.rst

Lines changed: 6 additions & 0 deletions
@@ -30,6 +30,7 @@ Top-level functions
    zeros_like
    ones_like
    dot
+   map_blocks

 Dataset
 =======

@@ -499,6 +500,8 @@ Dataset methods
    Dataset.persist
    Dataset.load
    Dataset.chunk
+   Dataset.unify_chunks
+   Dataset.map_blocks
    Dataset.filter_by_attrs
    Dataset.info

@@ -529,6 +532,8 @@ DataArray methods
    DataArray.persist
    DataArray.load
    DataArray.chunk
+   DataArray.unify_chunks
+   DataArray.map_blocks

 GroupBy objects
 ===============

@@ -629,6 +634,7 @@ Testing
    testing.assert_equal
    testing.assert_identical
    testing.assert_allclose
+   testing.assert_chunks_equal

 Exceptions
 ==========

doc/whats-new.rst

Lines changed: 5 additions & 0 deletions
@@ -49,6 +49,11 @@ Breaking changes
 New functions/methods
 ~~~~~~~~~~~~~~~~~~~~~

+- Added :py:func:`~xarray.map_blocks`, modeled after :py:func:`dask.array.map_blocks`.
+  Also added :py:meth:`Dataset.unify_chunks`, :py:meth:`DataArray.unify_chunks` and
+  :py:meth:`testing.assert_chunks_equal`. By `Deepak Cherian <https://github.com/dcherian>`_
+  and `Guido Imperiale <https://github.com/crusaderky>`_.
+
 Enhancements
 ~~~~~~~~~~~~
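As a sketch of what this entry announces, here is one way the chunk-unification API might be exercised. This is illustrative only, not part of the commit: it assumes dask is installed, and the dataset and variable names are made up.

    import numpy as np
    import xarray as xr

    # Two variables along the same dimension, deliberately chunked differently.
    ds = xr.Dataset({"a": ("x", np.arange(10)), "b": ("x", np.arange(10))})
    ds["a"] = ds["a"].chunk(5)
    ds["b"] = ds["b"].chunk(2)

    # unify_chunks rechunks every dask-backed variable so that all variables
    # share a common chunking along each dimension.
    unified = ds.unify_chunks()
    assert unified["a"].chunks == unified["b"].chunks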

xarray/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@
 from .core.dataarray import DataArray
 from .core.merge import merge, MergeError
 from .core.options import set_options
+from .core.parallel import map_blocks

 from .backends.api import (
     open_dataset,

xarray/coding/times.py

Lines changed: 0 additions & 1 deletion
@@ -22,7 +22,6 @@
     unpack_for_encoding,
 )

-
 # standard calendars recognized by cftime
 _STANDARD_CALENDARS = {"standard", "gregorian", "proleptic_gregorian"}

xarray/core/dask_array_compat.py

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
+from distutils.version import LooseVersion
+
+import dask.array as da
+import numpy as np
+from dask import __version__ as dask_version
+
+if LooseVersion(dask_version) >= LooseVersion("2.0.0"):
+    meta_from_array = da.utils.meta_from_array
+else:
+    # Copied from dask v2.4.0
+    # Used under the terms of Dask's license, see licenses/DASK_LICENSE.
+    import numbers
+
+    def meta_from_array(x, ndim=None, dtype=None):
+        """ Normalize an array to appropriate meta object
+
+        Parameters
+        ----------
+        x: array-like, callable
+            Either an object that looks sufficiently like a Numpy array,
+            or a callable that accepts shape and dtype keywords
+        ndim: int
+            Number of dimensions of the array
+        dtype: Numpy dtype
+            A valid input for ``np.dtype``
+
+        Returns
+        -------
+        array-like with zero elements of the correct dtype
+        """
+        # If using x._meta, x must be a Dask Array, some libraries (e.g. zarr)
+        # implement a _meta attribute that are incompatible with Dask Array._meta
+        if hasattr(x, "_meta") and isinstance(x, da.Array):
+            x = x._meta
+
+        if dtype is None and x is None:
+            raise ValueError("You must specify the meta or dtype of the array")
+
+        if np.isscalar(x):
+            x = np.array(x)
+
+        if x is None:
+            x = np.ndarray
+
+        if isinstance(x, type):
+            x = x(shape=(0,) * (ndim or 0), dtype=dtype)
+
+        if (
+            not hasattr(x, "shape")
+            or not hasattr(x, "dtype")
+            or not isinstance(x.shape, tuple)
+        ):
+            return x
+
+        if isinstance(x, list) or isinstance(x, tuple):
+            ndims = [
+                0
+                if isinstance(a, numbers.Number)
+                else a.ndim
+                if hasattr(a, "ndim")
+                else len(a)
+                for a in x
+            ]
+            a = [a if nd == 0 else meta_from_array(a, nd) for a, nd in zip(x, ndims)]
+            return a if isinstance(x, list) else tuple(x)
+
+        if ndim is None:
+            ndim = x.ndim
+
+        try:
+            meta = x[tuple(slice(0, 0, None) for _ in range(x.ndim))]
+            if meta.ndim != ndim:
+                if ndim > x.ndim:
+                    meta = meta[
+                        (Ellipsis,) + tuple(None for _ in range(ndim - meta.ndim))
+                    ]
+                    meta = meta[tuple(slice(0, 0, None) for _ in range(meta.ndim))]
+                elif ndim == 0:
+                    meta = meta.sum()
+                else:
+                    meta = meta.reshape((0,) * ndim)
+        except Exception:
+            meta = np.empty((0,) * ndim, dtype=dtype or x.dtype)
+
+        if np.isscalar(meta):
+            meta = np.array(meta)
+
+        if dtype and meta.dtype != dtype:
+            meta = meta.astype(dtype)
+
+        return meta
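A quick sketch of what the vendored ``meta_from_array`` returns, under the assumption that the module is importable at the path above (illustrative, not part of the commit):

    import numpy as np
    from xarray.core.dask_array_compat import meta_from_array

    # The meta is a zero-element array with the input's dtype and ndim.
    meta = meta_from_array(np.ones((4, 4), dtype="f4"))
    print(meta.shape, meta.dtype)  # (0, 0) float32

    # ndim may be overridden; the zero-element meta is expanded to match.
    print(meta_from_array(np.ones(3), ndim=2).shape)  # (0, 0)

map_blocks uses this meta to run the user function on mocked-up, zero-sized data and infer the structure of the result without computing anything.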

xarray/core/dataarray.py

Lines changed: 76 additions & 0 deletions
@@ -14,6 +14,7 @@
     Optional,
     Sequence,
     Tuple,
+    TypeVar,
     Union,
     cast,
     overload,

@@ -63,6 +64,8 @@
 )

 if TYPE_CHECKING:
+    T_DSorDA = TypeVar("T_DSorDA", "DataArray", Dataset)
+
     try:
         from dask.delayed import Delayed
     except ImportError:

@@ -3038,6 +3041,79 @@ def integrate(
         ds = self._to_temp_dataset().integrate(dim, datetime_unit)
         return self._from_temp_dataset(ds)

+    def unify_chunks(self) -> "DataArray":
+        """ Unify chunk size along all chunked dimensions of this DataArray.
+
+        Returns
+        -------
+
+        DataArray with consistent chunk sizes for all dask-array variables
+
+        See Also
+        --------
+
+        dask.array.core.unify_chunks
+        """
+        ds = self._to_temp_dataset().unify_chunks()
+        return self._from_temp_dataset(ds)
+
+    def map_blocks(
+        self,
+        func: "Callable[..., T_DSorDA]",
+        args: Sequence[Any] = (),
+        kwargs: Mapping[str, Any] = None,
+    ) -> "T_DSorDA":
+        """
+        Apply a function to each chunk of this DataArray. This method is experimental
+        and its signature may change.
+
+        Parameters
+        ----------
+        func: callable
+            User-provided function that accepts a DataArray as its first parameter. The
+            function will receive a subset of this DataArray, corresponding to one chunk
+            along each chunked dimension. ``func`` will be executed as
+            ``func(obj_subset, *args, **kwargs)``.
+
+            The function will be first run on mocked-up data, that looks like this array
+            but has sizes 0, to determine properties of the returned object such as
+            dtype, variable names, new dimensions and new indexes (if any).
+
+            This function must return either a single DataArray or a single Dataset.
+
+            This function cannot change size of existing dimensions, or add new chunked
+            dimensions.
+        args: Sequence
+            Passed verbatim to func after unpacking, after the sliced DataArray. xarray
+            objects, if any, will not be split by chunks. Passing dask collections is
+            not allowed.
+        kwargs: Mapping
+            Passed verbatim to func after unpacking. xarray objects, if any, will not be
+            split by chunks. Passing dask collections is not allowed.
+
+        Returns
+        -------
+        A single DataArray or Dataset with dask backend, reassembled from the outputs of
+        the function.
+
+        Notes
+        -----
+        This method is designed for when one needs to manipulate a whole xarray object
+        within each chunk. In the more common case where one can work on numpy arrays,
+        it is recommended to use apply_ufunc.
+
+        If none of the variables in this DataArray is backed by dask, calling this
+        method is equivalent to calling ``func(self, *args, **kwargs)``.
+
+        See Also
+        --------
+        dask.array.map_blocks, xarray.apply_ufunc, xarray.map_blocks,
+        xarray.Dataset.map_blocks
+        """
+        from .parallel import map_blocks
+
+        return map_blocks(func, self, args, kwargs)
+
     # this needs to be at the end, or mypy will confuse with `str`
     # https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names
     str = property(StringAccessor)
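To make the new method concrete, a minimal usage sketch (illustrative, not part of the commit; assumes dask is installed, and the data and function are made up):

    import numpy as np
    import xarray as xr

    arr = xr.DataArray(np.arange(12.0).reshape(3, 4), dims=("x", "y")).chunk({"x": 1})

    def demean(block):
        # Receives one chunk as a DataArray; must return a DataArray or
        # Dataset without resizing existing dimensions.
        return block - block.mean()

    result = arr.map_blocks(demean)  # lazy: demean runs once per chunk
    print(result.compute())

The same operation is available at the top level as ``xr.map_blocks(demean, arr)``, which the method delegates to.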
