Skip to content

Commit 47476eb

Browse files
authored
Pint support for variables (#3706)
* get fillna tests to pass
* get the _getitem_with_mask tests to pass
* silence the behavior change warning of pint
* don't use 0 as fill value since that has special behavior
* use concat as a class method
* use np.pad after trimming instead of concatenating a filled array
* rewrite the concat test to pass appropriate arrays
* use da.pad when dealing with dask arrays
* mark the failing pad tests as xfail when on a current pint version
* update whats-new.rst
* fix the import order
* test using pint master
* fix the install command
* reimplement the pad test to really work with units
* use np.logical_not instead
* use duck_array_ops to provide pad
* add comments explaining the order of the arguments to where
* mark the flipped parameter changes with a todo
* skip the identical tests
* remove the warnings filter
1 parent 1667e4c commit 47476eb

File tree

5 files changed

+89
-63
lines changed

5 files changed

+89
-63
lines changed

ci/azure/install.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ steps:
2929
git+https://github.com/zarr-developers/zarr \
3030
git+https://github.com/Unidata/cftime \
3131
git+https://github.com/mapbox/rasterio \
32+
git+https://github.com/hgrecco/pint \
3233
git+https://github.com/pydata/bottleneck
3334
condition: eq(variables['UPSTREAM_DEV'], 'true')
3435
displayName: Install upstream dev dependencies

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ Breaking changes
2424

2525
New Features
2626
~~~~~~~~~~~~
27+
- Implement pint support for variables. (:issue:`3594`, :pull:`3706`)
28+
By `Justus Magin <https://github.com/keewis>`_.
2729

2830
Bug fixes
2931
~~~~~~~~~

xarray/core/duck_array_ops.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def notnull(data):
121121
isin = _dask_or_eager_func("isin", array_args=slice(2))
122122
take = _dask_or_eager_func("take")
123123
broadcast_to = _dask_or_eager_func("broadcast_to")
124+
pad = _dask_or_eager_func("pad")
124125

125126
_concatenate = _dask_or_eager_func("concatenate", list_of_args=True)
126127
_stack = _dask_or_eager_func("stack", list_of_args=True)
@@ -261,7 +262,10 @@ def where_method(data, cond, other=dtypes.NA):
261262

262263

263264
def fillna(data, other):
264-
return where(isnull(data), other, data)
265+
# we need to pass data first so pint has a chance of returning the
266+
# correct unit
267+
# TODO: revert after https://github.com/hgrecco/pint/issues/1019 is fixed
268+
return where(notnull(data), data, other)
265269

266270

267271
def concatenate(arrays, axis=0):

xarray/core/variable.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -742,7 +742,10 @@ def _getitem_with_mask(self, key, fill_value=dtypes.NA):
742742

743743
data = as_indexable(self._data)[actual_indexer]
744744
mask = indexing.create_mask(indexer, self.shape, data)
745-
data = duck_array_ops.where(mask, fill_value, data)
745+
# we need to invert the mask in order to pass data first. This helps
746+
# pint to choose the correct unit
747+
# TODO: revert after https://github.com/hgrecco/pint/issues/1019 is fixed
748+
data = duck_array_ops.where(np.logical_not(mask), data, fill_value)
746749
else:
747750
# array cannot be indexed along dimensions of size 0, so just
748751
# build the mask directly instead.
@@ -1099,24 +1102,16 @@ def _shift_one_dim(self, dim, count, fill_value=dtypes.NA):
10991102
else:
11001103
dtype = self.dtype
11011104

1102-
shape = list(self.shape)
1103-
shape[axis] = min(abs(count), shape[axis])
1105+
width = min(abs(count), self.shape[axis])
1106+
dim_pad = (width, 0) if count >= 0 else (0, width)
1107+
pads = [(0, 0) if d != dim else dim_pad for d in self.dims]
11041108

1105-
if isinstance(trimmed_data, dask_array_type):
1106-
chunks = list(trimmed_data.chunks)
1107-
chunks[axis] = (shape[axis],)
1108-
full = functools.partial(da.full, chunks=chunks)
1109-
else:
1110-
full = np.full
1111-
1112-
filler = full(shape, fill_value, dtype=dtype)
1113-
1114-
if count > 0:
1115-
arrays = [filler, trimmed_data]
1116-
else:
1117-
arrays = [trimmed_data, filler]
1118-
1119-
data = duck_array_ops.concatenate(arrays, axis)
1109+
data = duck_array_ops.pad(
1110+
trimmed_data.astype(dtype),
1111+
pads,
1112+
mode="constant",
1113+
constant_values=fill_value,
1114+
)
11201115

11211116
if isinstance(data, dask_array_type):
11221117
# chunked data should come out with the same chunks; this makes

xarray/tests/test_units.py

Lines changed: 68 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import operator
2+
from distutils.version import LooseVersion
23

34
import numpy as np
45
import pandas as pd
@@ -19,6 +20,7 @@
1920
unit_registry = pint.UnitRegistry(force_ndarray=True)
2021
Quantity = unit_registry.Quantity
2122

23+
2224
pytestmark = [
2325
pytest.mark.skipif(
2426
not IS_NEP18_ACTIVE, reason="NUMPY_EXPERIMENTAL_ARRAY_FUNCTION is not enabled"
@@ -1536,27 +1538,17 @@ def test_missing_value_detection(self, func):
15361538
@pytest.mark.parametrize(
15371539
"unit,error",
15381540
(
1539-
pytest.param(
1540-
1,
1541-
DimensionalityError,
1542-
id="no_unit",
1543-
marks=pytest.mark.xfail(reason="uses 0 as a replacement"),
1544-
),
1541+
pytest.param(1, DimensionalityError, id="no_unit"),
15451542
pytest.param(
15461543
unit_registry.dimensionless, DimensionalityError, id="dimensionless"
15471544
),
15481545
pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
1549-
pytest.param(
1550-
unit_registry.cm,
1551-
None,
1552-
id="compatible_unit",
1553-
marks=pytest.mark.xfail(reason="converts to fill value's unit"),
1554-
),
1546+
pytest.param(unit_registry.cm, None, id="compatible_unit"),
15551547
pytest.param(unit_registry.m, None, id="identical_unit"),
15561548
),
15571549
)
15581550
def test_missing_value_fillna(self, unit, error):
1559-
value = 0
1551+
value = 10
15601552
array = (
15611553
np.array(
15621554
[
@@ -1595,13 +1587,7 @@ def test_missing_value_fillna(self, unit, error):
15951587
pytest.param(1, id="no_unit"),
15961588
pytest.param(unit_registry.dimensionless, id="dimensionless"),
15971589
pytest.param(unit_registry.s, id="incompatible_unit"),
1598-
pytest.param(
1599-
unit_registry.cm,
1600-
id="compatible_unit",
1601-
marks=pytest.mark.xfail(
1602-
reason="checking for identical units does not work properly, yet"
1603-
),
1604-
),
1590+
pytest.param(unit_registry.cm, id="compatible_unit",),
16051591
pytest.param(unit_registry.m, id="identical_unit"),
16061592
),
16071593
)
@@ -1612,7 +1598,17 @@ def test_missing_value_fillna(self, unit, error):
16121598
pytest.param(True, id="with_conversion"),
16131599
),
16141600
)
1615-
@pytest.mark.parametrize("func", (method("equals"), method("identical")), ids=repr)
1601+
@pytest.mark.parametrize(
1602+
"func",
1603+
(
1604+
method("equals"),
1605+
pytest.param(
1606+
method("identical"),
1607+
marks=pytest.mark.skip(reason="behaviour of identical is unclear"),
1608+
),
1609+
),
1610+
ids=repr,
1611+
)
16161612
def test_comparisons(self, func, unit, convert_data, dtype):
16171613
array = np.linspace(0, 1, 9).astype(dtype)
16181614
quantity1 = array * unit_registry.m
@@ -1762,14 +1758,7 @@ def test_1d_math(self, func, unit, error, dtype):
17621758
unit_registry.dimensionless, DimensionalityError, id="dimensionless"
17631759
),
17641760
pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
1765-
pytest.param(
1766-
unit_registry.cm,
1767-
None,
1768-
id="compatible_unit",
1769-
marks=pytest.mark.xfail(
1770-
reason="getitem_with_mask converts to the unit of other"
1771-
),
1772-
),
1761+
pytest.param(unit_registry.cm, None, id="compatible_unit"),
17731762
pytest.param(unit_registry.m, None, id="identical_unit"),
17741763
),
17751764
)
@@ -1853,12 +1842,7 @@ def test_squeeze(self, dtype):
18531842
),
18541843
method("reduce", np.std, "x"),
18551844
method("round", 2),
1856-
pytest.param(
1857-
method("shift", {"x": -2}),
1858-
marks=pytest.mark.xfail(
1859-
reason="trying to concatenate ndarray to quantity"
1860-
),
1861-
),
1845+
method("shift", {"x": -2}),
18621846
method("transpose", "y", "x"),
18631847
),
18641848
ids=repr,
@@ -1933,7 +1917,6 @@ def test_unstack(self, dtype):
19331917
assert_units_equal(expected, actual)
19341918
xr.testing.assert_identical(expected, actual)
19351919

1936-
@pytest.mark.xfail(reason="ignores units")
19371920
@pytest.mark.parametrize(
19381921
"unit,error",
19391922
(
@@ -1948,25 +1931,28 @@ def test_unstack(self, dtype):
19481931
)
19491932
def test_concat(self, unit, error, dtype):
19501933
array1 = (
1951-
np.linspace(0, 5, 3 * 10).reshape(3, 10).astype(dtype) * unit_registry.m
1934+
np.linspace(0, 5, 9 * 10).reshape(3, 6, 5).astype(dtype) * unit_registry.m
19521935
)
1953-
array2 = np.linspace(5, 10, 10 * 2).reshape(10, 2).astype(dtype) * unit
1936+
array2 = np.linspace(5, 10, 10 * 3).reshape(3, 2, 5).astype(dtype) * unit
19541937

1955-
variable = xr.Variable(("x", "y"), array1)
1956-
other = xr.Variable(("y", "z"), array2)
1938+
variable = xr.Variable(("x", "y", "z"), array1)
1939+
other = xr.Variable(("x", "y", "z"), array2)
19571940

19581941
if error is not None:
19591942
with pytest.raises(error):
1960-
variable.concat(other)
1943+
xr.Variable.concat([variable, other], dim="y")
19611944

19621945
return
19631946

19641947
units = extract_units(variable)
19651948
expected = attach_units(
1966-
strip_units(variable).concat(strip_units(convert_units(other, units))),
1949+
xr.Variable.concat(
1950+
[strip_units(variable), strip_units(convert_units(other, units))],
1951+
dim="y",
1952+
),
19671953
units,
19681954
)
1969-
actual = variable.concat(other)
1955+
actual = xr.Variable.concat([variable, other], dim="y")
19701956

19711957
assert_units_equal(expected, actual)
19721958
xr.testing.assert_identical(expected, actual)
@@ -2036,6 +2022,43 @@ def test_no_conflicts(self, unit, dtype):
20362022

20372023
assert expected == actual
20382024

2025+
def test_pad(self, dtype):
2026+
data = np.arange(4 * 3 * 2).reshape(4, 3, 2).astype(dtype) * unit_registry.m
2027+
v = xr.Variable(["x", "y", "z"], data)
2028+
2029+
xr_args = [{"x": (2, 1)}, {"y": (0, 3)}, {"x": (3, 1), "z": (2, 0)}]
2030+
np_args = [
2031+
((2, 1), (0, 0), (0, 0)),
2032+
((0, 0), (0, 3), (0, 0)),
2033+
((3, 1), (0, 0), (2, 0)),
2034+
]
2035+
for xr_arg, np_arg in zip(xr_args, np_args):
2036+
actual = v.pad_with_fill_value(**xr_arg)
2037+
expected = xr.Variable(
2038+
v.dims,
2039+
np.pad(
2040+
v.data.astype(float),
2041+
np_arg,
2042+
mode="constant",
2043+
constant_values=np.nan,
2044+
),
2045+
)
2046+
xr.testing.assert_identical(expected, actual)
2047+
assert_units_equal(expected, actual)
2048+
assert isinstance(actual._data, type(v._data))
2049+
2050+
# for the boolean array, we pad False
2051+
data = np.full_like(data, False, dtype=bool).reshape(4, 3, 2)
2052+
v = xr.Variable(["x", "y", "z"], data)
2053+
for xr_arg, np_arg in zip(xr_args, np_args):
2054+
actual = v.pad_with_fill_value(fill_value=data.flat[0], **xr_arg)
2055+
expected = xr.Variable(
2056+
v.dims,
2057+
np.pad(v.data, np_arg, mode="constant", constant_values=v.data.flat[0]),
2058+
)
2059+
xr.testing.assert_identical(actual, expected)
2060+
assert_units_equal(expected, actual)
2061+
20392062
@pytest.mark.parametrize(
20402063
"unit,error",
20412064
(
@@ -2044,7 +2067,8 @@ def test_no_conflicts(self, unit, dtype):
20442067
DimensionalityError,
20452068
id="no_unit",
20462069
marks=pytest.mark.xfail(
2047-
reason="is not treated the same as dimensionless"
2070+
LooseVersion(pint.__version__) < LooseVersion("0.10.2"),
2071+
reason="bug in pint's implementation of np.pad",
20482072
),
20492073
),
20502074
pytest.param(

0 commit comments

Comments
 (0)