Skip to content

Commit e05fdde

Browse files
Recreate @gajomi's #2070 to keep attrs when calling astype() (#4314)
* Recreate @gajomi's #2070 to keep attrs when calling astype() * Accept keep_attrs flag, use apply_ufunc in variable.py, expand tests * Ignore casting if we get a TypeError (indicates sparse <= 0.10.0) * add to whats new * fix dtype test * improve error handling and check if sparse array * test and docstring updates * catch sparse error a bit more elegantly * check using sparse_array_type Co-authored-by: Keewis <[email protected]>
1 parent 68d0e0d commit e05fdde

File tree

7 files changed

+143
-3
lines changed

7 files changed

+143
-3
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ v0.16.1 (unreleased)
2121

2222
Breaking changes
2323
~~~~~~~~~~~~~~~~
24-
24+
- :py:meth:`DataArray.astype` and :py:meth:`Dataset.astype` now preserve attributes. Keep the
25+
old behavior by passing `keep_attrs=False` (:issue:`2049`, :pull:`4314`).
26+
By `Dan Nowacki <https://github.com/dnowacki-usgs>`_ and `Gabriel Joel Mitchell <https://github.com/gajomi>`_.
2527

2628
New Features
2729
~~~~~~~~~~~~

xarray/core/common.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1299,6 +1299,53 @@ def isin(self, test_elements):
12991299
dask="allowed",
13001300
)
13011301

1302+
def astype(self, dtype, casting="unsafe", copy=True, keep_attrs=True):
1303+
"""
1304+
Copy of the xarray object, with data cast to a specified type.
1305+
Leaves coordinate dtype unchanged.
1306+
1307+
Parameters
1308+
----------
1309+
dtype : str or dtype
1310+
Typecode or data-type to which the array is cast.
1311+
casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
1312+
Controls what kind of data casting may occur. Defaults to 'unsafe'
1313+
for backwards compatibility.
1314+
1315+
* 'no' means the data types should not be cast at all.
1316+
* 'equiv' means only byte-order changes are allowed.
1317+
* 'safe' means only casts which can preserve values are allowed.
1318+
* 'same_kind' means only safe casts or casts within a kind,
1319+
like float64 to float32, are allowed.
1320+
* 'unsafe' means any data conversions may be done.
1321+
copy : bool, optional
1322+
By default, astype always returns a newly allocated array. If this
1323+
is set to False and the `dtype` requirement is satisfied, the input
1324+
array is returned instead of a copy.
1325+
keep_attrs : bool, optional
1326+
By default, astype keeps attributes. Set to False to remove
1327+
attributes in the returned object.
1328+
1329+
Returns
1330+
-------
1331+
out : same as object
1332+
New object with data cast to the specified type.
1333+
1334+
See also
1335+
--------
1336+
np.ndarray.astype
1337+
dask.array.Array.astype
1338+
"""
1339+
from .computation import apply_ufunc
1340+
1341+
return apply_ufunc(
1342+
duck_array_ops.astype,
1343+
self,
1344+
kwargs=dict(dtype=dtype, casting=casting, copy=copy),
1345+
keep_attrs=keep_attrs,
1346+
dask="allowed",
1347+
)
1348+
13021349
def __enter__(self: T) -> T:
13031350
return self
13041351

xarray/core/duck_array_ops.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77
import datetime
88
import inspect
99
import warnings
10+
from distutils.version import LooseVersion
1011
from functools import partial
1112

1213
import numpy as np
1314
import pandas as pd
1415

1516
from . import dask_array_compat, dask_array_ops, dtypes, npcompat, nputils
1617
from .nputils import nanfirst, nanlast
17-
from .pycompat import cupy_array_type, dask_array_type
18+
from .pycompat import cupy_array_type, dask_array_type, sparse_array_type
1819

1920
try:
2021
import dask.array as dask_array
@@ -150,6 +151,28 @@ def trapz(y, x, axis):
150151
)
151152

152153

154+
def astype(data, **kwargs):
155+
try:
156+
import sparse
157+
except ImportError:
158+
sparse = None
159+
160+
if (
161+
sparse is not None
162+
and isinstance(data, sparse_array_type)
163+
and LooseVersion(sparse.__version__) < LooseVersion("0.11.0")
164+
and "casting" in kwargs
165+
):
166+
warnings.warn(
167+
"The current version of sparse does not support the 'casting' argument. It will be ignored in the call to astype().",
168+
RuntimeWarning,
169+
stacklevel=4,
170+
)
171+
kwargs.pop("casting")
172+
173+
return data.astype(**kwargs)
174+
175+
153176
def asarray(data, xp=np):
154177
return (
155178
data

xarray/core/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
NUMPY_SAME_METHODS = ["item", "searchsorted"]
4343
# methods which don't modify the data shape, so the result should still be
4444
# wrapped in an Variable/DataArray
45-
NUMPY_UNARY_METHODS = ["astype", "argsort", "clip", "conj", "conjugate"]
45+
NUMPY_UNARY_METHODS = ["argsort", "clip", "conj", "conjugate"]
4646
PANDAS_UNARY_FUNCTIONS = ["isnull", "notnull"]
4747
# methods which remove an axis
4848
REDUCE_METHODS = ["all", "any"]

xarray/core/variable.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,52 @@ def data(self, data):
360360
)
361361
self._data = data
362362

363+
def astype(self, dtype, casting="unsafe", copy=True, keep_attrs=True):
364+
"""
365+
Copy of the Variable object, with data cast to a specified type.
366+
367+
Parameters
368+
----------
369+
dtype : str or dtype
370+
Typecode or data-type to which the array is cast.
371+
casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
372+
Controls what kind of data casting may occur. Defaults to 'unsafe'
373+
for backwards compatibility.
374+
375+
* 'no' means the data types should not be cast at all.
376+
* 'equiv' means only byte-order changes are allowed.
377+
* 'safe' means only casts which can preserve values are allowed.
378+
* 'same_kind' means only safe casts or casts within a kind,
379+
like float64 to float32, are allowed.
380+
* 'unsafe' means any data conversions may be done.
381+
copy : bool, optional
382+
By default, astype always returns a newly allocated array. If this
383+
is set to False and the `dtype` requirement is satisfied, the input
384+
array is returned instead of a copy.
385+
keep_attrs : bool, optional
386+
By default, astype keeps attributes. Set to False to remove
387+
attributes in the returned object.
388+
389+
Returns
390+
-------
391+
out : same as object
392+
New object with data cast to the specified type.
393+
394+
See also
395+
--------
396+
np.ndarray.astype
397+
dask.array.Array.astype
398+
"""
399+
from .computation import apply_ufunc
400+
401+
return apply_ufunc(
402+
duck_array_ops.astype,
403+
self,
404+
kwargs=dict(dtype=dtype, casting=casting, copy=copy),
405+
keep_attrs=keep_attrs,
406+
dask="allowed",
407+
)
408+
363409
def load(self, **kwargs):
364410
"""Manually trigger loading of this variable's data from disk or a
365411
remote source into memory and return this variable.

xarray/tests/test_dataarray.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1874,6 +1874,19 @@ def test_array_interface(self):
18741874
bar = Variable(["x", "y"], np.zeros((10, 20)))
18751875
assert_equal(self.dv, np.maximum(self.dv, bar))
18761876

1877+
def test_astype_attrs(self):
1878+
for v in [self.va.copy(), self.mda.copy(), self.ds.copy()]:
1879+
v.attrs["foo"] = "bar"
1880+
assert v.attrs == v.astype(float).attrs
1881+
assert not v.astype(float, keep_attrs=False).attrs
1882+
1883+
def test_astype_dtype(self):
1884+
original = DataArray([-1, 1, 2, 3, 1000])
1885+
converted = original.astype(float)
1886+
assert_array_equal(original, converted)
1887+
assert np.issubdtype(original.dtype, np.integer)
1888+
assert np.issubdtype(converted.dtype, np.floating)
1889+
18771890
def test_is_null(self):
18781891
x = np.random.RandomState(42).randn(5, 6)
18791892
x[x < 0] = np.nan

xarray/tests/test_dataset.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5634,6 +5634,15 @@ def test_pad(self):
56345634
np.testing.assert_equal(padded["var1"].isel(dim2=[0, -1]).data, 42)
56355635
np.testing.assert_equal(padded["dim2"][[0, -1]].data, np.nan)
56365636

5637+
def test_astype_attrs(self):
5638+
data = create_test_data(seed=123)
5639+
data.attrs["foo"] = "bar"
5640+
5641+
assert data.attrs == data.astype(float).attrs
5642+
assert data.var1.attrs == data.astype(float).var1.attrs
5643+
assert not data.astype(float, keep_attrs=False).attrs
5644+
assert not data.astype(float, keep_attrs=False).var1.attrs
5645+
56375646

56385647
# Py.test tests
56395648

0 commit comments

Comments
 (0)