Skip to content

Commit 0b2180a

Browse files
Merge pull request #906 from IntelPython/dlpack-use-of-stream
Added support for stream other than None
2 parents 3d79ef9 + b79b4b1 commit 0b2180a

File tree

2 files changed

+19
-6
lines changed

2 files changed

+19
-6
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -779,13 +779,14 @@ cdef class usm_ndarray:
779779
NotImplementedError: when non-default value of `stream` keyword
780780
is used.
781781
"""
782-
if stream is None:
783-
return c_dlpack.to_dlpack_capsule(self)
782+
_caps = c_dlpack.to_dlpack_capsule(self)
783+
if (stream is None or type(stream) is not dpctl.SyclQueue or
784+
stream == self.sycl_queue):
785+
pass
784786
else:
785-
raise NotImplementedError(
786-
"Only stream=None is supported. "
787-
"Use `dpctl.SyclQueue.submit_barrier` to synchronize queues."
788-
)
787+
ev = self.sycl_queue.submit_barrier()
788+
stream.submit_barrier(dependent_events=[ev])
789+
return _caps
789790

790791
def __dlpack_device__(self):
791792
"""

dpctl/tests/test_usm_ndarray_dlpack.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,18 @@ def test_dlpack_exporter(typestr, usm_type):
7979
assert caps_fn(caps2, b"dltensor")
8080

8181

82+
def test_dlpack_exporter_stream():
83+
try:
84+
q1 = dpctl.SyclQueue()
85+
q2 = dpctl.SyclQueue()
86+
except dpctl.SyclQueueCreationError:
87+
pytest.skip("Could not create default queues")
88+
X = dpt.empty((64,), dtype="u1", sycl_queue=q1)
89+
cap1 = X.__dlpack__(stream=q1)
90+
cap2 = X.__dlpack__(stream=q2)
91+
assert type(cap1) is type(cap2)
92+
93+
8294
@pytest.mark.parametrize("shape", [tuple(), (2,), (3, 0, 1), (2, 2, 2)])
8395
def test_from_dlpack(shape, typestr, usm_type):
8496
all_root_devices = dpctl.get_devices()

0 commit comments

Comments
 (0)