From 3d7b1270e53fa02b11fdcf8e734652025ed073ba Mon Sep 17 00:00:00 2001 From: caenrigen <31376402+caenrigen@users.noreply.github.com> Date: Tue, 2 Mar 2021 12:53:07 +0100 Subject: [PATCH 1/6] added support for numpy.bool_ --- xarray/backends/api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 4fa34b39925..fe3dcc7cb52 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -218,10 +218,10 @@ def check_attr(name, value): "serialization to netCDF files" ) - if not isinstance(value, (str, Number, np.ndarray, np.number, list, tuple)): + if not isinstance(value, (str, Number, np.ndarray, np.number, np.bool_, list, tuple)): raise TypeError( f"Invalid value for attr {name!r}: {value!r} must be a number, " - "a string, an ndarray or a list/tuple of " + "a string, an ndarray, a bool/numpy.bool_, or a list/tuple of " "numbers/strings for serialization to netCDF " "files" ) From b0f532a7bf746d1fd5d42868bb4df89e259cfdf9 Mon Sep 17 00:00:00 2001 From: caenrigen <31376402+caenrigen@users.noreply.github.com> Date: Tue, 2 Mar 2021 13:20:23 +0100 Subject: [PATCH 2/6] add tests fix black --- xarray/backends/api.py | 4 +++- xarray/tests/test_backends.py | 8 ++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index fe3dcc7cb52..9d5589e7232 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -218,7 +218,9 @@ def check_attr(name, value): "serialization to netCDF files" ) - if not isinstance(value, (str, Number, np.ndarray, np.number, np.bool_, list, tuple)): + if not isinstance( + value, (str, Number, np.ndarray, np.number, np.bool_, list, tuple) + ): raise TypeError( f"Invalid value for attr {name!r}: {value!r} must be a number, " "a string, an ndarray, a bool/numpy.bool_, or a list/tuple of " diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index d15736e608d..3e031936e86 100644 --- 
a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -2541,6 +2541,14 @@ def test_complex(self, invalid_netcdf, warntype, num_warns): assert recorded_num_warns == num_warns + def test_numpy_bool_(self): + # h5netcdf loads booleans as numpy.bool_, this type needs to be supported + # when writing invalid_netcdf datasets in order to support a roundtrip + expected = Dataset({"x": ("y", np.ones(5), {"numpy_bool": np.bool_(True)})}) + save_kwargs = {"invalid_netcdf": True} + with self.roundtrip(expected, save_kwargs=save_kwargs) as actual: + assert_equal(expected, actual) + def test_cross_engine_read_write_netcdf4(self): # Drop dim3, because its labels include strings. These appear to be # not properly read with python-netCDF4, which converts them into From 65e54c660bd9caf5ad8f904b367ab4387255e557 Mon Sep 17 00:00:00 2001 From: caenrigen <31376402+caenrigen@users.noreply.github.com> Date: Tue, 2 Mar 2021 22:38:22 +0100 Subject: [PATCH 3/6] improve tests --- xarray/tests/test_backends.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 3e031936e86..f6c00a2a9a9 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -2547,7 +2547,7 @@ def test_numpy_bool_(self): expected = Dataset({"x": ("y", np.ones(5), {"numpy_bool": np.bool_(True)})}) save_kwargs = {"invalid_netcdf": True} with self.roundtrip(expected, save_kwargs=save_kwargs) as actual: - assert_equal(expected, actual) + assert_identical(expected, actual) def test_cross_engine_read_write_netcdf4(self): # Drop dim3, because its labels include strings. 
These appear to be From a70bbfbc36cb8e5e07f238dae9335800aca2d0e0 Mon Sep 17 00:00:00 2001 From: caenrigen <31376402+caenrigen@users.noreply.github.com> Date: Tue, 2 Mar 2021 22:38:40 +0100 Subject: [PATCH 4/6] update docstring --- xarray/backends/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 9d5589e7232..53ddc0a7cdc 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -201,7 +201,7 @@ def check_name(name): def _validate_attrs(dataset): """`attrs` must have a string key and a value which is either: a number, - a string, an ndarray or a list/tuple of numbers/strings. + a string, an ndarray, a bool/numpy.bool_, or a list/tuple of numbers/strings. """ def check_attr(name, value): From 284201afea36366b61d772f26e37c56c079c6837 Mon Sep 17 00:00:00 2001 From: caenrigen <31376402+caenrigen@users.noreply.github.com> Date: Tue, 2 Mar 2021 22:56:29 +0100 Subject: [PATCH 5/6] added whats-new entry --- doc/whats-new.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index eed4e16eb62..9c94c51120c 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -34,7 +34,8 @@ Deprecations Bug fixes ~~~~~~~~~ - +- Added support for `numpy.bool_` attributes in roundtrips using `h5netcdf` engine with `invalid_netcdf=True` [which casts `bool`s to `numpy.bool_`] (:issue:`4981`, :pull:`4986`). + By `Victor Negîrneac <https://github.com/caenrigen>`_. 
Documentation ~~~~~~~~~~~~~ From aa2ff83962e0c1a4355a88e8b97f2c20d862e940 Mon Sep 17 00:00:00 2001 From: caenrigen <31376402+caenrigen@users.noreply.github.com> Date: Mon, 8 Mar 2021 17:19:35 +0100 Subject: [PATCH 6/6] restrict numpy.bool_ to the specific use-case --- xarray/backends/api.py | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 53ddc0a7cdc..aca6524a3bd 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -199,12 +199,21 @@ def check_name(name): check_name(k) -def _validate_attrs(dataset): +def _validate_attrs(dataset, invalid_netcdf=False): """`attrs` must have a string key and a value which is either: a number, - a string, an ndarray, a bool/numpy.bool_, or a list/tuple of numbers/strings. + a string, an ndarray, a list/tuple of numbers/strings, or a numpy.bool_. + + Notes + ----- + A numpy.bool_ is only allowed when using the h5netcdf engine with + `invalid_netcdf=True`. """ - def check_attr(name, value): + valid_types = (str, Number, np.ndarray, np.number, list, tuple) + if invalid_netcdf: + valid_types += (np.bool_,) + + def check_attr(name, value, valid_types): if isinstance(name, str): if not name: raise ValueError( @@ -218,24 +227,21 @@ def check_attr(name, value): "serialization to netCDF files" ) - if not isinstance( - value, (str, Number, np.ndarray, np.number, np.bool_, list, tuple) - ): + if not isinstance(value, valid_types): raise TypeError( - f"Invalid value for attr {name!r}: {value!r} must be a number, " - "a string, an ndarray, a bool/numpy.bool_, or a list/tuple of " - "numbers/strings for serialization to netCDF " - "files" + f"Invalid value for attr {name!r}: {value!r}. 
For serialization to " + "netCDF files, its value must be of one of the following types: " + f"{', '.join([vtype.__name__ for vtype in valid_types])}" ) # Check attrs on the dataset itself for k, v in dataset.attrs.items(): - check_attr(k, v) + check_attr(k, v, valid_types) # Check attrs on each variable within the dataset for variable in dataset.variables.values(): for k, v in variable.attrs.items(): - check_attr(k, v) + check_attr(k, v, valid_types) def _protect_dataset_variables_inplace(dataset, cache): @@ -1058,7 +1064,7 @@ def to_netcdf( # validate Dataset keys, DataArray names, and attr keys/values _validate_dataset_names(dataset) - _validate_attrs(dataset) + _validate_attrs(dataset, invalid_netcdf=invalid_netcdf and engine == "h5netcdf") try: store_open = WRITEABLE_STORES[engine]