diff --git a/dpctl/tensor/_device.py b/dpctl/tensor/_device.py index 86b9877806..96185e507d 100644 --- a/dpctl/tensor/_device.py +++ b/dpctl/tensor/_device.py @@ -124,6 +124,18 @@ def wait(self): """ self.sycl_queue_.wait() + def __eq__(self, other): + """Equality comparison based on underlying ``sycl_queue``.""" + if isinstance(other, Device): + return self.sycl_queue.__eq__(other.sycl_queue) + elif isinstance(other, dpctl.SyclQueue): + return self.sycl_queue.__eq__(other) + return False + + def __hash__(self): + """Compute object's hash value.""" + return self.sycl_queue.__hash__() + def normalize_queue_device(sycl_queue=None, device=None): """ diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index e3f772086e..d2548355c5 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -1575,3 +1575,19 @@ def test_asarray_uint64(): Xnp = np.ndarray(1, dtype=np.uint64) X = dpt.asarray(Xnp) assert X.dtype == Xnp.dtype + + +def test_Device(): + try: + dev = dpctl.select_default_device() + d1 = dpt.Device.create_device(dev) + d2 = dpt.Device.create_device(dev) + except (dpctl.SyclQueueCreationError, dpctl.SyclDeviceCreationError): + pytest.skip( + "Could not create default device, or a queue that targets it" + ) + assert d1 == d2 + dict = {d1: 1} + assert dict[d2] == 1 + assert d1 == d2.sycl_queue + assert not d1 == Ellipsis