Skip to content

Commit 1d3dee0

Browse files
authored
nd-rolling (#4219)
* nd-rolling * remove unnecessary print * black * finding a bug... * make tests for ndrolling pass * make center and window_dim a dict * A cleanup. * Revert test_units * make test pass * More tests. * more docs * mypy * improve whatsnew * Improve doc * Support nd-rolling in dask correctly * Cleanup according to max's comment * flake8 * black * stop using either_dict_or_kwargs * Better tests. * typo * mypy * typo2
1 parent e04e21d commit 1d3dee0

File tree

12 files changed

+362
-133
lines changed

12 files changed

+362
-133
lines changed

doc/computation.rst

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,16 @@ a value when aggregating:
188188
r = arr.rolling(y=3, center=True, min_periods=2)
189189
r.mean()
190190
191+
From version 0.17, xarray supports multidimensional rolling:
192+
193+
.. ipython:: python
194+
195+
r = arr.rolling(x=2, y=3, min_periods=2)
196+
r.mean()
197+
191198
.. tip::
192199

193-
Note that rolling window aggregations are faster and use less memory when bottleneck_ is installed. This only applies to numpy-backed xarray objects.
200+
Note that rolling window aggregations are faster and use less memory when bottleneck_ is installed. This only applies to numpy-backed xarray objects with 1d-rolling.
194201

195202
.. _bottleneck: https://github.com/pydata/bottleneck/
196203

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ Breaking changes
2525

2626
New Features
2727
~~~~~~~~~~~~
28+
- :py:meth:`~xarray.DataArray.rolling` and :py:meth:`~xarray.Dataset.rolling`
29+
now accept more than 1 dimension. (:pull:`4219`)
30+
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
2831
- Build :py:meth:`CFTimeIndex.__repr__` explicitly as :py:class:`pandas.Index`. Add ``calendar`` as a new
2932
property for :py:class:`CFTimeIndex` and show ``calendar`` and ``length`` in
3033
:py:meth:`CFTimeIndex.__repr__` (:issue:`2416`, :pull:`4092`)

xarray/core/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,7 @@ def rolling(
786786
self,
787787
dim: Mapping[Hashable, int] = None,
788788
min_periods: int = None,
789-
center: bool = False,
789+
center: Union[bool, Mapping[Hashable, bool]] = False,
790790
keep_attrs: bool = None,
791791
**window_kwargs: int,
792792
):
@@ -802,7 +802,7 @@ def rolling(
802802
Minimum number of observations in window required to have a value
803803
(otherwise result is NA). The default, None, is equivalent to
804804
setting min_periods equal to the size of the window.
805-
center : boolean, default False
805+
center : bool or mapping of hashable to bool, default False
806806
Set the labels at the center of the window.
807807
keep_attrs : bool, optional
808808
If True, the object's attributes (`attrs`) will be copied from

xarray/core/dask_array_ops.py

Lines changed: 62 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -32,69 +32,80 @@ def rolling_window(a, axis, window, center, fill_value):
3232
"""
3333
import dask.array as da
3434

35+
if not hasattr(axis, "__len__"):
36+
axis = [axis]
37+
window = [window]
38+
center = [center]
39+
3540
orig_shape = a.shape
36-
if axis < 0:
37-
axis = a.ndim + axis
3841
depth = {d: 0 for d in range(a.ndim)}
39-
depth[axis] = int(window / 2)
40-
# For evenly sized window, we need to crop the first point of each block.
41-
offset = 1 if window % 2 == 0 else 0
42-
43-
if depth[axis] > min(a.chunks[axis]):
44-
raise ValueError(
45-
"For window size %d, every chunk should be larger than %d, "
46-
"but the smallest chunk size is %d. Rechunk your array\n"
47-
"with a larger chunk size or a chunk size that\n"
48-
"more evenly divides the shape of your array."
49-
% (window, depth[axis], min(a.chunks[axis]))
50-
)
51-
52-
# Although da.overlap pads values to boundaries of the array,
53-
# the size of the generated array is smaller than what we want
54-
# if center == False.
55-
if center:
56-
start = int(window / 2) # 10 -> 5, 9 -> 4
57-
end = window - 1 - start
58-
else:
59-
start, end = window - 1, 0
60-
pad_size = max(start, end) + offset - depth[axis]
61-
drop_size = 0
62-
# pad_size becomes more than 0 when the overlapped array is smaller than
63-
# needed. In this case, we need to enlarge the original array by padding
64-
# before overlapping.
65-
if pad_size > 0:
66-
if pad_size < depth[axis]:
67-
# overlapping requires each chunk larger than depth. If pad_size is
68-
# smaller than the depth, we enlarge this and truncate it later.
69-
drop_size = depth[axis] - pad_size
70-
pad_size = depth[axis]
71-
shape = list(a.shape)
72-
shape[axis] = pad_size
73-
chunks = list(a.chunks)
74-
chunks[axis] = (pad_size,)
75-
fill_array = da.full(shape, fill_value, dtype=a.dtype, chunks=chunks)
76-
a = da.concatenate([fill_array, a], axis=axis)
77-
42+
offset = [0] * a.ndim
43+
drop_size = [0] * a.ndim
44+
pad_size = [0] * a.ndim
45+
for ax, win, cent in zip(axis, window, center):
46+
if ax < 0:
47+
ax = a.ndim + ax
48+
depth[ax] = int(win / 2)
49+
# For evenly sized window, we need to crop the first point of each block.
50+
offset[ax] = 1 if win % 2 == 0 else 0
51+
52+
if depth[ax] > min(a.chunks[ax]):
53+
raise ValueError(
54+
"For window size %d, every chunk should be larger than %d, "
55+
"but the smallest chunk size is %d. Rechunk your array\n"
56+
"with a larger chunk size or a chunk size that\n"
57+
"more evenly divides the shape of your array."
58+
% (win, depth[ax], min(a.chunks[ax]))
59+
)
60+
61+
# Although da.overlap pads values to boundaries of the array,
62+
# the size of the generated array is smaller than what we want
63+
# if center == False.
64+
if cent:
65+
start = int(win / 2) # 10 -> 5, 9 -> 4
66+
end = win - 1 - start
67+
else:
68+
start, end = win - 1, 0
69+
pad_size[ax] = max(start, end) + offset[ax] - depth[ax]
70+
drop_size[ax] = 0
71+
# pad_size becomes more than 0 when the overlapped array is smaller than
72+
# needed. In this case, we need to enlarge the original array by padding
73+
# before overlapping.
74+
if pad_size[ax] > 0:
75+
if pad_size[ax] < depth[ax]:
76+
# overlapping requires each chunk larger than depth. If pad_size is
77+
# smaller than the depth, we enlarge this and truncate it later.
78+
drop_size[ax] = depth[ax] - pad_size[ax]
79+
pad_size[ax] = depth[ax]
80+
81+
# TODO maybe following two lines can be summarized.
82+
a = da.pad(
83+
a, [(p, 0) for p in pad_size], mode="constant", constant_values=fill_value
84+
)
7885
boundary = {d: fill_value for d in range(a.ndim)}
7986

8087
# create overlap arrays
8188
ag = da.overlap.overlap(a, depth=depth, boundary=boundary)
8289

83-
# apply rolling func
84-
def func(x, window, axis=-1):
90+
def func(x, window, axis):
8591
x = np.asarray(x)
86-
rolling = nputils._rolling_window(x, window, axis)
87-
return rolling[(slice(None),) * axis + (slice(offset, None),)]
88-
89-
chunks = list(a.chunks)
90-
chunks.append(window)
92+
index = [slice(None)] * x.ndim
93+
for ax, win in zip(axis, window):
94+
x = nputils._rolling_window(x, win, ax)
95+
index[ax] = slice(offset[ax], None)
96+
return x[tuple(index)]
97+
98+
chunks = list(a.chunks) + window
99+
new_axis = [a.ndim + i for i in range(len(axis))]
91100
out = ag.map_blocks(
92-
func, dtype=a.dtype, new_axis=a.ndim, chunks=chunks, window=window, axis=axis
101+
func, dtype=a.dtype, new_axis=new_axis, chunks=chunks, window=window, axis=axis
93102
)
94103

95104
# crop boundary.
96-
index = (slice(None),) * axis + (slice(drop_size, drop_size + orig_shape[axis]),)
97-
return out[index]
105+
index = [slice(None)] * a.ndim
106+
for ax in axis:
107+
index[ax] = slice(drop_size[ax], drop_size[ax] + orig_shape[ax])
108+
return out[tuple(index)]
98109

99110

100111
def least_squares(lhs, rhs, rcond=None, skipna=False):

xarray/core/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5968,7 +5968,7 @@ def polyfit(
59685968
skipna_da = np.any(da.isnull())
59695969

59705970
dims_to_stack = [dimname for dimname in da.dims if dimname != dim]
5971-
stacked_coords = {}
5971+
stacked_coords: Dict[Hashable, DataArray] = {}
59725972
if dims_to_stack:
59735973
stacked_dim = utils.get_temp_dimname(dims_to_stack, "stacked")
59745974
rhs = da.transpose(dim, *dims_to_stack).stack(

xarray/core/nputils.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -135,14 +135,22 @@ def __setitem__(self, key, value):
135135
def rolling_window(a, axis, window, center, fill_value):
136136
""" rolling window with padding. """
137137
pads = [(0, 0) for s in a.shape]
138-
if center:
139-
start = int(window / 2) # 10 -> 5, 9 -> 4
140-
end = window - 1 - start
141-
pads[axis] = (start, end)
142-
else:
143-
pads[axis] = (window - 1, 0)
138+
if not hasattr(axis, "__len__"):
139+
axis = [axis]
140+
window = [window]
141+
center = [center]
142+
143+
for ax, win, cent in zip(axis, window, center):
144+
if cent:
145+
start = int(win / 2) # 10 -> 5, 9 -> 4
146+
end = win - 1 - start
147+
pads[ax] = (start, end)
148+
else:
149+
pads[ax] = (win - 1, 0)
144150
a = np.pad(a, pads, mode="constant", constant_values=fill_value)
145-
return _rolling_window(a, window, axis)
151+
for ax, win in zip(axis, window):
152+
a = _rolling_window(a, win, ax)
153+
return a
146154

147155

148156
def _rolling_window(a, window, axis=-1):

0 commit comments

Comments
 (0)