Skip to content

Commit d1f7cb8

Browse files
authored
Improve interp performance (#4069)
* Fixes 2223 * more tests * add @requires_scipy to test * fix tests * black * update whatsnew. Added a test for nearest
1 parent 1de38bc commit d1f7cb8

File tree

4 files changed

+40
-7
lines changed

4 files changed

+40
-7
lines changed

doc/whats-new.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,13 @@ Breaking changes
3434
(:pull:`3274`)
3535
By `Elliott Sales de Andrade <https://github.com/QuLogic>`_
3636

37+
Enhancements
38+
~~~~~~~~~~~~
39+
- Performance improvement of :py:meth:`DataArray.interp` and :py:func:`Dataset.interp`
40+
For orthogonal linear- and nearest-neighbor interpolation, we do 1d-interpolation sequentially
41+
rather than interpolating in multidimensional space. (:issue:`2223`)
42+
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
43+
3744
New Features
3845
~~~~~~~~~~~~
3946

xarray/core/missing.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,19 @@ def interp(var, indexes_coords, method, **kwargs):
619619
# default behavior
620620
kwargs["bounds_error"] = kwargs.get("bounds_error", False)
621621

622+
# check if the interpolation can be done in orthogonal manner
623+
if (
624+
len(indexes_coords) > 1
625+
and method in ["linear", "nearest"]
626+
and all(dest[1].ndim == 1 for dest in indexes_coords.values())
627+
and len(set([d[1].dims[0] for d in indexes_coords.values()]))
628+
== len(indexes_coords)
629+
):
630+
# interpolate sequentially
631+
for dim, dest in indexes_coords.items():
632+
var = interp(var, {dim: dest}, method, **kwargs)
633+
return var
634+
622635
# target dimensions
623636
dims = list(indexes_coords)
624637
x, new_x = zip(*[indexes_coords[d] for d in dims])
@@ -659,7 +672,7 @@ def interp_func(var, x, new_x, method, kwargs):
659672
New coordinates. Should not contain NaN.
660673
method: string
661674
{'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'} for
662-
1-dimensional itnterpolation.
675+
1-dimensional interpolation.
663676
{'linear', 'nearest'} for multidimensional interpolation
664677
**kwargs:
665678
Optional keyword arguments to be passed to scipy.interpolator

xarray/testing.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,7 @@
1010
from xarray.core.indexes import default_indexes
1111
from xarray.core.variable import IndexVariable, Variable
1212

13-
__all__ = (
14-
"assert_allclose",
15-
"assert_chunks_equal",
16-
"assert_equal",
17-
"assert_identical",
18-
)
13+
__all__ = ("assert_allclose", "assert_chunks_equal", "assert_equal", "assert_identical")
1914

2015

2116
def _decode_string_data(data):

xarray/tests/test_interp.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -699,3 +699,21 @@ def test_3641():
699699
times = xr.cftime_range("0001", periods=3, freq="500Y")
700700
da = xr.DataArray(range(3), dims=["time"], coords=[times])
701701
da.interp(time=["0002-05-01"])
702+
703+
704+
@requires_scipy
705+
@pytest.mark.parametrize("method", ["nearest", "linear"])
706+
def test_decompose(method):
707+
da = xr.DataArray(
708+
np.arange(6).reshape(3, 2),
709+
dims=["x", "y"],
710+
coords={"x": [0, 1, 2], "y": [-0.1, -0.3]},
711+
)
712+
x_new = xr.DataArray([0.5, 1.5, 2.5], dims=["x1"])
713+
y_new = xr.DataArray([-0.15, -0.25], dims=["y1"])
714+
x_broadcast, y_broadcast = xr.broadcast(x_new, y_new)
715+
assert x_broadcast.ndim == 2
716+
717+
actual = da.interp(x=x_new, y=y_new, method=method).drop(("x", "y"))
718+
expected = da.interp(x=x_broadcast, y=y_broadcast, method=method).drop(("x", "y"))
719+
assert_allclose(actual, expected)

0 commit comments

Comments
 (0)