Skip to content

Commit 03450eb

Browse files
Add dlpack support with tests and docstrings
1 parent 0f7420e commit 03450eb

File tree

3 files changed

+52
-2
lines changed

3 files changed

+52
-2
lines changed

dpnp/dpnp_array.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def __bool__(self):
140140
return self._array_obj.__bool__()
141141

142142
# '__class__',
143-
143+
144144
def __complex__(self):
145145
return self._array_obj.__complex__()
146146

@@ -153,6 +153,12 @@ def __complex__(self):
153153
# '__divmod__',
154154
# '__doc__',
155155

156+
def __dlpack__(self, stream=None):
157+
return self._array_obj.__dlpack__(stream=stream)
158+
159+
def __dlpack_device__(self):
160+
return self._array_obj.__dlpack_device__()
161+
156162
def __eq__(self, other):
157163
return dpnp.equal(self, other)
158164

@@ -190,7 +196,7 @@ def __gt__(self, other):
190196
# '__imatmul__',
191197
# '__imod__',
192198
# '__imul__',
193-
199+
194200
def __index__(self):
195201
return self._array_obj.__index__()
196202

dpnp/dpnp_iface.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
"default_float_type",
6565
"dpnp_queue_initialize",
6666
"dpnp_queue_is_cpu",
67+
"from_dlpack",
6768
"get_dpnp_descriptor",
6869
"get_include",
6970
"get_normalized_queue_device"
@@ -222,6 +223,30 @@ def default_float_type(device=None, sycl_queue=None):
222223
return map_dtype_to_device(float64, _sycl_queue.sycl_device)
223224

224225

226+
def from_dlpack(obj):
227+
"""
228+
Create a dpnp array from a Python object implementing the ``__dlpack__``
229+
protocol.
230+
231+
See https://dmlc.github.io/dlpack/latest/ for more details.
232+
233+
Parameters
234+
----------
235+
obj : A Python object representing an array that implements the ``__dlpack__``
236+
and ``__dlpack_device__`` methods.
237+
238+
Returns
239+
-------
240+
array : dpnp_array
241+
242+
"""
243+
244+
usm_ary = dpt.from_dlpack(obj)
245+
dpnp_ary = dpnp_array.__new__(dpnp_array)
246+
dpnp_ary._array_obj = usm_ary
247+
return dpnp_ary
248+
249+
225250
def get_dpnp_descriptor(ext_obj,
226251
copy_when_strides=True,
227252
copy_when_nondefault_queue=True,

tests/test_dparray.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,25 @@ def test_astype(arr, arr_dtype, res_dtype):
2323
assert_array_equal(expected, result)
2424

2525

26+
@pytest.mark.parametrize("arr_dtype", get_all_dtypes())
27+
@pytest.mark.parametrize("shape", [tuple(), (2,), (3, 0, 1), (2, 2, 2)])
28+
def test_from_dlpack(arr_dtype,shape):
29+
X = dpnp.empty(shape=shape,dtype=arr_dtype)
30+
Y = dpnp.from_dlpack(X)
31+
assert_array_equal(X, Y)
32+
assert X.__dlpack_device__() == Y.__dlpack_device__()
33+
assert X.shape == Y.shape
34+
assert X.dtype == Y.dtype or (
35+
str(X.dtype) == "bool" and str(Y.dtype) == "uint8"
36+
)
37+
assert X.sycl_device == Y.sycl_device
38+
assert X.usm_type == Y.usm_type
39+
if Y.ndim:
40+
V = Y[::-1]
41+
W = dpnp.from_dlpack(V)
42+
assert V.strides == W.strides
43+
44+
2645
@pytest.mark.parametrize("arr_dtype", get_all_dtypes())
2746
@pytest.mark.parametrize("arr",
2847
[[-2, -1, 0, 1, 2], [[-2, -1], [1, 2]], []],

0 commit comments

Comments
 (0)