Skip to content

Commit 4c158da

Browse files
Fix remarks, add _create_from_usm_ndarray func and move tests to test_sycl_queue
1 parent 75695ce commit 4c158da

File tree

5 files changed

+61
-42
lines changed

5 files changed

+61
-42
lines changed

dpnp/dpnp_array.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,16 @@ def __truediv__(self, other):
319319

320320
# '__xor__',
321321

322+
@staticmethod
323+
def _create_from_usm_ndarray(usm_ary : dpt.usm_ndarray):
324+
if not isinstance(usm_ary, dpt.usm_ndarray):
325+
raise TypeError(
326+
f"Expected dpctl.tensor.usm_ndarray, got {type(usm_ary)}"
327+
)
328+
res = dpnp_array.__new__(dpnp_array)
329+
res._array_obj = usm_ary
330+
return res
331+
322332
def all(self, axis=None, out=None, keepdims=False):
323333
"""
324334
Returns True if all elements evaluate to True.

dpnp/dpnp_iface.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def default_float_type(device=None, sycl_queue=None):
223223
return map_dtype_to_device(float64, _sycl_queue.sycl_device)
224224

225225

226-
def from_dlpack(obj):
226+
def from_dlpack(obj, /):
227227
"""
228228
Create a dpnp array from a Python object implementing the ``__dlpack__``
229229
protocol.
@@ -232,19 +232,20 @@ def from_dlpack(obj):
232232
233233
Parameters
234234
----------
235-
obj : A Python object representing an array that implements the ``__dlpack__``
235+
obj : object
236+
A Python object representing an array that implements the ``__dlpack__``
236237
and ``__dlpack_device__`` methods.
237238
238239
Returns
239240
-------
240-
array : dpnp_array
241+
out : dpnp_array
242+
Returns a new dpnp array containing the data from another array
243+
(obj) with the ``__dlpack__`` method on the same device as object.
241244
242245
"""
243246

244247
usm_ary = dpt.from_dlpack(obj)
245-
dpnp_ary = dpnp_array.__new__(dpnp_array)
246-
dpnp_ary._array_obj = usm_ary
247-
return dpnp_ary
248+
return dpnp_array._create_from_usm_ndarray(usm_ary)
248249

249250

250251
def get_dpnp_descriptor(ext_obj,

tests/helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def get_all_dtypes(no_bool=False,
3232
dtypes.append(dpnp.complex64)
3333
if dev.has_aspect_fp64:
3434
dtypes.append(dpnp.complex128)
35-
35+
3636
# add None value to validate a default dtype
3737
if not no_none:
3838
dtypes.append(None)

tests/test_dparray.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -23,39 +23,6 @@ def test_astype(arr, arr_dtype, res_dtype):
2323
assert_array_equal(expected, result)
2424

2525

26-
@pytest.mark.parametrize("arr_dtype", get_all_dtypes())
27-
@pytest.mark.parametrize("shape", [tuple(), (2,), (3, 0, 1), (2, 2, 2)])
28-
def test_from_dlpack(arr_dtype,shape):
29-
X = dpnp.empty(shape=shape,dtype=arr_dtype)
30-
Y = dpnp.from_dlpack(X)
31-
assert_array_equal(X, Y)
32-
assert X.__dlpack_device__() == Y.__dlpack_device__()
33-
assert X.shape == Y.shape
34-
assert X.dtype == Y.dtype or (
35-
str(X.dtype) == "bool" and str(Y.dtype) == "uint8"
36-
)
37-
assert X.sycl_device == Y.sycl_device
38-
assert X.usm_type == Y.usm_type
39-
if Y.ndim:
40-
V = Y[::-1]
41-
W = dpnp.from_dlpack(V)
42-
assert V.strides == W.strides
43-
44-
@pytest.mark.parametrize("arr_dtype", get_all_dtypes())
45-
def test_from_dlpack_with_dpt(arr_dtype):
46-
X = dpt.empty((64,),dtype=arr_dtype)
47-
Y = dpnp.from_dlpack(X)
48-
assert_array_equal(X, Y)
49-
assert isinstance(Y, dpnp.dpnp_array.dpnp_array)
50-
assert X.__dlpack_device__() == Y.__dlpack_device__()
51-
assert X.shape == Y.shape
52-
assert X.dtype == Y.dtype or (
53-
str(X.dtype) == "bool" and str(Y.dtype) == "uint8"
54-
)
55-
assert X.sycl_device == Y.sycl_device
56-
assert X.usm_type == Y.usm_type
57-
58-
5926
@pytest.mark.parametrize("arr_dtype", get_all_dtypes())
6027
@pytest.mark.parametrize("arr",
6128
[[-2, -1, 0, 1, 2], [[-2, -1], [1, 2]], []],

tests/test_sycl_queue.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
import pytest
2+
from .helper import get_all_dtypes
23

34
import dpnp
45
import dpctl
56
import numpy
67

8+
from numpy.testing import (
9+
assert_array_equal
10+
)
11+
712

813
list_of_backend_str = [
914
"host",
@@ -155,7 +160,7 @@ def test_array_creation_like(func, kwargs, device_x, device_y):
155160

156161
dpnp_kwargs = dict(kwargs)
157162
dpnp_kwargs['device'] = device_y
158-
163+
159164
y = getattr(dpnp, func)(x, **dpnp_kwargs)
160165
numpy.testing.assert_array_equal(y_orig, y)
161166
assert_sycl_queue_equal(y.sycl_queue, x.to_device(device_y).sycl_queue)
@@ -637,7 +642,7 @@ def test_eig(device):
637642
dpnp_val_queue = dpnp_val.get_array().sycl_queue
638643
dpnp_vec_queue = dpnp_vec.get_array().sycl_queue
639644

640-
# compare queue and device
645+
# compare queue and device
641646
assert_sycl_queue_equal(dpnp_val_queue, expected_queue)
642647
assert_sycl_queue_equal(dpnp_vec_queue, expected_queue)
643648

@@ -806,3 +811,39 @@ def test_array_copy(device, func, device_param, queue_param):
806811
result = dpnp.array(dpnp_data, **kwargs)
807812

808813
assert_sycl_queue_equal(result.sycl_queue, dpnp_data.sycl_queue)
814+
815+
816+
@pytest.mark.parametrize("device",
817+
valid_devices,
818+
ids=[device.filter_string for device in valid_devices])
819+
#TODO need to delete no_bool=True when use dlpack > 0.7 version
820+
@pytest.mark.parametrize("arr_dtype", get_all_dtypes(no_float16=True, no_bool=True))
821+
@pytest.mark.parametrize("shape", [tuple(), (2,), (3, 0, 1), (2, 2, 2)])
822+
def test_from_dlpack(arr_dtype, shape, device):
823+
X = dpnp.empty(shape=shape, dtype=arr_dtype, device=device)
824+
Y = dpnp.from_dlpack(X)
825+
assert_array_equal(X, Y)
826+
assert X.__dlpack_device__() == Y.__dlpack_device__()
827+
assert X.sycl_device == Y.sycl_device
828+
assert X.sycl_context == Y.sycl_context
829+
assert X.usm_type == Y.usm_type
830+
if Y.ndim:
831+
V = Y[::-1]
832+
W = dpnp.from_dlpack(V)
833+
assert V.strides == W.strides
834+
835+
836+
@pytest.mark.parametrize("device",
837+
valid_devices,
838+
ids=[device.filter_string for device in valid_devices])
839+
#TODO need to delete no_bool=True when use dlpack > 0.7 version
840+
@pytest.mark.parametrize("arr_dtype", get_all_dtypes(no_float16=True, no_bool=True))
841+
def test_from_dlpack_with_dpt(arr_dtype, device):
842+
X = dpctl.tensor.empty((64,), dtype=arr_dtype, device=device)
843+
Y = dpnp.from_dlpack(X)
844+
assert_array_equal(X, Y)
845+
assert isinstance(Y, dpnp.dpnp_array.dpnp_array)
846+
assert X.__dlpack_device__() == Y.__dlpack_device__()
847+
assert X.sycl_device == Y.sycl_device
848+
assert X.sycl_context == Y.sycl_context
849+
assert X.usm_type == Y.usm_type

0 commit comments

Comments
 (0)