diff --git a/dpctl/tensor/_copy_utils.py b/dpctl/tensor/_copy_utils.py index 86d87ff46c..c220b61b26 100644 --- a/dpctl/tensor/_copy_utils.py +++ b/dpctl/tensor/_copy_utils.py @@ -23,6 +23,7 @@ import dpctl.tensor as dpt import dpctl.tensor._tensor_impl as ti import dpctl.utils +from dpctl.tensor._ctors import _get_dtype from dpctl.tensor._device import normalize_queue_device __doc__ = ( @@ -364,7 +365,8 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True): array (usm_ndarray): An input array. new_dtype (dtype): - The data type of the resulting array. + The data type of the resulting array. If `None`, gives default + floating point type supported by device where `array` is allocated. order ({"C", "F", "A", "K"}, optional): Controls memory layout of the resulting array if a copy is returned. @@ -392,7 +394,7 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True): "Recognized values are 'A', 'C', 'F', or 'K'" ) ary_dtype = usm_ary.dtype - target_dtype = dpt.dtype(newdtype) + target_dtype = _get_dtype(newdtype, usm_ary.sycl_queue) if not dpt.can_cast(ary_dtype, target_dtype, casting=casting): raise TypeError( f"Can not cast from {ary_dtype} to {newdtype} " diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index da2201c75f..67bc162b81 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -1196,6 +1196,11 @@ def test_astype(): assert np.allclose(dpt.to_numpy(Y), np.full(Y.shape, 7, dtype="f4")) Y = dpt.astype(X[::2, ::-1], "i4", order="K", copy=False) assert Y.usm_data is X.usm_data + Y = dpt.astype(X, None, order="K") + if X.sycl_queue.sycl_device.has_aspect_fp64: + assert Y.dtype is dpt.float64 + else: + assert Y.dtype is dpt.float32 def test_astype_invalid_order():