diff --git a/dpnp/dpnp_array.py b/dpnp/dpnp_array.py index fb8e1fcef12d..fd2d06f74285 100644 --- a/dpnp/dpnp_array.py +++ b/dpnp/dpnp_array.py @@ -562,7 +562,15 @@ def asnumpy(self): return dpt.asnumpy(self._array_obj) - def astype(self, dtype, order="K", casting="unsafe", subok=True, copy=True): + def astype( + self, + dtype, + order="K", + casting="unsafe", + subok=True, + copy=True, + device=None, + ): """ Copy the array with data type casting. @@ -597,6 +605,13 @@ def astype(self, dtype, order="K", casting="unsafe", subok=True, copy=True): this is set to ``False``, and the `dtype`, `order`, and `subok` requirements are satisfied, the input array is returned instead of a copy. + device : {None, string, SyclDevice, SyclQueue}, optional + An array API concept of device where the output array is created. + The `device` can be ``None`` (the default), an OneAPI filter selector + string, an instance of :class:`dpctl.SyclDevice` corresponding to + a non-partitioned SYCL device, an instance of :class:`dpctl.SyclQueue`, + or a `Device` object returned by + :obj:`dpnp.dpnp_array.dpnp_array.device` property. Default: ``None``. Returns ------- @@ -626,7 +641,9 @@ def astype(self, dtype, order="K", casting="unsafe", subok=True, copy=True): f"subok={subok} is currently not supported" ) - return dpnp.astype(self, dtype, order=order, casting=casting, copy=copy) + return dpnp.astype( + self, dtype, order=order, casting=casting, copy=copy, device=device + ) # 'base', # 'byteswap', diff --git a/dpnp/dpnp_iface.py b/dpnp/dpnp_iface.py index 0dfd63dab217..49e7b41c01c9 100644 --- a/dpnp/dpnp_iface.py +++ b/dpnp/dpnp_iface.py @@ -180,7 +180,7 @@ def asnumpy(a, order="C"): # pylint: disable=redefined-outer-name -def astype(x1, dtype, order="K", casting="unsafe", copy=True): +def astype(x1, dtype, order="K", casting="unsafe", copy=True, device=None): """ Copy the array with data type casting. @@ -213,6 +213,13 @@ def astype(x1, dtype, order="K", casting="unsafe", copy=True): By default, ``astype`` always returns a newly allocated array. If this is set to ``False``, and the `dtype`, `order`, and `subok` requirements are satisfied, the input array is returned instead of a copy. + device : {None, string, SyclDevice, SyclQueue}, optional + An array API concept of device where the output array is created. + The `device` can be ``None`` (the default), an OneAPI filter selector + string, an instance of :class:`dpctl.SyclDevice` corresponding to + a non-partitioned SYCL device, an instance of :class:`dpctl.SyclQueue`, + or a `Device` object returned by + :obj:`dpnp.dpnp_array.dpnp_array.device` property. Default: ``None``. Returns ------- @@ -228,7 +235,7 @@ def astype(x1, dtype, order="K", casting="unsafe", copy=True): x1_obj = dpnp.get_usm_ndarray(x1) array_obj = dpt.astype( - x1_obj, dtype, order=order, casting=casting, copy=copy + x1_obj, dtype, order=order, casting=casting, copy=copy, device=device ) # return x1 if dpctl returns a zero copy of x1_obj diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index 3073b8806e5e..99334cfabfcd 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -2211,3 +2211,22 @@ def test_histogram_bin_edges(weights, device): edges_queue = result_edges.sycl_queue assert_sycl_queue_equal(edges_queue, iv.sycl_queue) + + +@pytest.mark.parametrize( + "device_x", + valid_devices, + ids=[device.filter_string for device in valid_devices], +) +@pytest.mark.parametrize( + "device_y", + valid_devices, + ids=[device.filter_string for device in valid_devices], +) +def test_astype(device_x, device_y): + x = dpnp.array([1, 2, 3], dtype="i4", device=device_x) + y = dpnp.astype(x, dtype="f4") + assert_sycl_queue_equal(y.sycl_queue, x.sycl_queue) + sycl_queue = dpctl.SyclQueue(device_y) + y = dpnp.astype(x, dtype="f4", device=sycl_queue) + assert_sycl_queue_equal(y.sycl_queue, sycl_queue)