diff --git a/CHANGELOG.md b/CHANGELOG.md index 50c6bcd926..615a99be5b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +* Improved error in constructors `tensor.full` and `tensor.full_like` when provided a non-numeric fill value [gh-1878](https://github.com/IntelPython/dpctl/pull/1878) + ### Maintenance * Update black version used in Python code style workflow [gh-1828](https://github.com/IntelPython/dpctl/pull/1828) diff --git a/dpctl/tensor/_ctors.py b/dpctl/tensor/_ctors.py index d3d8fa64f5..37236cad6b 100644 --- a/dpctl/tensor/_ctors.py +++ b/dpctl/tensor/_ctors.py @@ -15,6 +15,7 @@ # limitations under the License. import operator +from numbers import Number import numpy as np @@ -1037,6 +1038,19 @@ def _cast_fill_val(fill_val, dt): return fill_val +def _validate_fill_value(fill_val): + """ + Validates that `fill_val` is a numeric or boolean scalar. + """ + # TODO: verify if `np.True_` and `np.False_` should be instances of + # Number in NumPy, like other NumPy scalars and like Python bools + # check for `np.bool_` separately as NumPy<2 has no `np.bool` + if not isinstance(fill_val, Number) and not isinstance(fill_val, np.bool_): + raise TypeError( + f"array cannot be filled with scalar of type {type(fill_val)}" + ) + + def full( shape, fill_value, @@ -1110,6 +1124,8 @@ def full( sycl_queue=sycl_queue, ) return dpt.copy(dpt.broadcast_to(X, shape), order=order) + else: + _validate_fill_value(fill_value) sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device) usm_type = usm_type if usm_type is not None else "device" @@ -1480,6 +1496,8 @@ def full_like( ) _manager.add_event_pair(hev, copy_ev) return res + else: + _validate_fill_value(fill_value) dtype = _get_dtype(dtype, sycl_queue, ref_type=type(fill_value)) res = _empty_like_orderK(x, dtype, usm_type, sycl_queue) diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index 046a2d9496..777a46f090 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -2621,3 +2621,14 @@ def test_setitem_from_numpy_contig(): expected = dpt.reshape(dpt.arange(-10, 10, dtype=fp_dt), (4, 5)) assert dpt.all(dpt.flip(Xdpt, axis=-1) == expected) + + +def test_full_functions_raise_type_error(): + get_queue_or_skip() + + with pytest.raises(TypeError): + dpt.full(1, "0") + + x = dpt.ones(1, dtype="i4") + with pytest.raises(TypeError): + dpt.full_like(x, "0")