diff --git a/dpctl/tensor/_usmarray.pyx b/dpctl/tensor/_usmarray.pyx index 70d5d3191c..48f60b76b8 100644 --- a/dpctl/tensor/_usmarray.pyx +++ b/dpctl/tensor/_usmarray.pyx @@ -779,13 +779,14 @@ cdef class usm_ndarray: NotImplementedError: when non-default value of `stream` keyword is used. """ - if stream is None: - return c_dlpack.to_dlpack_capsule(self) + _caps = c_dlpack.to_dlpack_capsule(self) + if (stream is None or type(stream) is not dpctl.SyclQueue or + stream == self.sycl_queue): + pass else: - raise NotImplementedError( - "Only stream=None is supported. " - "Use `dpctl.SyclQueue.submit_barrier` to synchronize queues." - ) + ev = self.sycl_queue.submit_barrier() + stream.submit_barrier(dependent_events=[ev]) + return _caps def __dlpack_device__(self): """ diff --git a/dpctl/tests/test_usm_ndarray_dlpack.py b/dpctl/tests/test_usm_ndarray_dlpack.py index f86a4ce05f..a329fecc80 100644 --- a/dpctl/tests/test_usm_ndarray_dlpack.py +++ b/dpctl/tests/test_usm_ndarray_dlpack.py @@ -79,6 +79,18 @@ def test_dlpack_exporter(typestr, usm_type): assert caps_fn(caps2, b"dltensor") +def test_dlpack_exporter_stream(): + try: + q1 = dpctl.SyclQueue() + q2 = dpctl.SyclQueue() + except dpctl.SyclQueueCreationError: + pytest.skip("Could not create default queues") + X = dpt.empty((64,), dtype="u1", sycl_queue=q1) + cap1 = X.__dlpack__(stream=q1) + cap2 = X.__dlpack__(stream=q2) + assert type(cap1) is type(cap2) + + @pytest.mark.parametrize("shape", [tuple(), (2,), (3, 0, 1), (2, 2, 2)]) def test_from_dlpack(shape, typestr, usm_type): all_root_devices = dpctl.get_devices()