From b0fa998295bc3634e8b2cf430929455556c39a3b Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Fri, 27 Jan 2023 15:47:29 -0600 Subject: [PATCH 1/2] dpctl.tensor.Device.__eq__ and .__hash__ special methods added. Array API's Device object must support __eq__ special method. Addition of __hash__ also allows to use Device object instances as keys in dictionaries. --- dpctl/tensor/_device.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/dpctl/tensor/_device.py b/dpctl/tensor/_device.py index 86b9877806..96185e507d 100644 --- a/dpctl/tensor/_device.py +++ b/dpctl/tensor/_device.py @@ -124,6 +124,18 @@ def wait(self): """ self.sycl_queue_.wait() + def __eq__(self, other): + """Equality comparison based on underlying ``sycl_queue``.""" + if isinstance(other, Device): + return self.sycl_queue.__eq__(other.sycl_queue) + elif isinstance(other, dpctl.SyclQueue): + return self.sycl_queue.__eq__(other) + return False + + def __hash__(self): + """Compute object's hash value.""" + return self.sycl_queue.__hash__() + def normalize_queue_device(sycl_queue=None, device=None): """ From 3dda4f48c6c2cabac43b3f793f69c77fa7ec12c9 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Fri, 27 Jan 2023 15:57:33 -0600 Subject: [PATCH 2/2] Added tests for Device.__eq__ and Device.__hash__ --- dpctl/tests/test_usm_ndarray_ctor.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index e3f772086e..d2548355c5 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -1575,3 +1575,19 @@ def test_asarray_uint64(): Xnp = np.ndarray(1, dtype=np.uint64) X = dpt.asarray(Xnp) assert X.dtype == Xnp.dtype + + +def test_Device(): + try: + dev = dpctl.select_default_device() + d1 = dpt.Device.create_device(dev) + d2 = dpt.Device.create_device(dev) + except (dpctl.SyclQueueCreationError, dpctl.SyclDeviceCreationError): + pytest.skip( + "Could not create default device, or a queue that targets it" + ) + assert d1 == d2 + dict = {d1: 1} + assert dict[d2] == 1 + assert d1 == d2.sycl_queue + assert not d1 == Ellipsis