Skip to content

Commit 75695ce

Browse files
Add a test for dlpack with dpt
1 parent 03450eb commit 75695ce

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

tests/test_dparray.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,20 @@ def test_from_dlpack(arr_dtype,shape):
4141
W = dpnp.from_dlpack(V)
4242
assert V.strides == W.strides
4343

44+
@pytest.mark.parametrize("arr_dtype", get_all_dtypes())
45+
def test_from_dlpack_with_dpt(arr_dtype):
46+
X = dpt.empty((64,),dtype=arr_dtype)
47+
Y = dpnp.from_dlpack(X)
48+
assert_array_equal(X, Y)
49+
assert isinstance(Y, dpnp.dpnp_array.dpnp_array)
50+
assert X.__dlpack_device__() == Y.__dlpack_device__()
51+
assert X.shape == Y.shape
52+
assert X.dtype == Y.dtype or (
53+
str(X.dtype) == "bool" and str(Y.dtype) == "uint8"
54+
)
55+
assert X.sycl_device == Y.sycl_device
56+
assert X.usm_type == Y.usm_type
57+
4458

4559
@pytest.mark.parametrize("arr_dtype", get_all_dtypes())
4660
@pytest.mark.parametrize("arr",

0 commit comments

Comments
 (0)