Skip to content

Commit 61da03f

Browse files
Implemented from_dlpack(array)
1 parent e98a7d8 commit 61da03f

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

dpctl/tensor/_dlpack.pyx

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,5 +305,18 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
305305

306306
cpdef from_dlpack(array):
307307
"""Constructs `usm_ndarray` from a Python object that implements
308-
`__dlpack__` protocol."""
309-
pass
308+
`__dlpack__` protocol.
309+
"""
310+
if not hasattr(array, "__dlpack__"):
311+
raise TypeError(
312+
"The argument of type {type(array)} does not implement "
313+
"`__dlpack__` method."
314+
)
315+
dlpack_attr = getattr(array, "__dlpack__")
316+
if not callable(dlpack_attr):
317+
raise TypeError(
318+
"The argument of type {type(array)} does not implement "
319+
"`__dlpack__` method."
320+
)
321+
dlpack_capsule = dlpack_attr()
322+
return from_dlpack_capsule(dlpack_capsule)

0 commit comments

Comments
 (0)