From 3e6c9b68a112ac1520cd356b8a5949fabd217986 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 18 Jul 2024 16:56:16 -0500 Subject: [PATCH 1/2] Resolves gh-1738 The unique_inverse and unique_all now always return inverse_index data fields in default indexing data type as per Python Array API specification. --- dpctl/tensor/_set_functions.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/dpctl/tensor/_set_functions.py b/dpctl/tensor/_set_functions.py index 2e2df751a9..bbba301da4 100644 --- a/dpctl/tensor/_set_functions.py +++ b/dpctl/tensor/_set_functions.py @@ -425,8 +425,7 @@ def unique_inverse(x): ) _manager.add_event_pair(ht_ev, sub_ev) - inv_dt = dpt.int64 if x.size > dpt.iinfo(dpt.int32).max else dpt.int32 - inv = dpt.empty_like(x, dtype=inv_dt, order="C") + inv = dpt.empty_like(x, dtype=ind_dt, order="C") ht_ev, ssl_ev = _searchsorted_left( hay=unique_vals, needles=x, @@ -608,8 +607,7 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult: ) _manager.add_event_pair(ht_ev, sub_ev) - inv_dt = dpt.int64 if x.size > dpt.iinfo(dpt.int32).max else dpt.int32 - inv = dpt.empty_like(x, dtype=inv_dt, order="C") + inv = dpt.empty_like(x, dtype=ind_dt, order="C") ht_ev, ssl_ev = _searchsorted_left( hay=unique_vals, needles=x, From 7e2c43c9b171f2d33302ab17779ee06c5fc017fa Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 18 Jul 2024 16:57:20 -0500 Subject: [PATCH 2/2] Add test based on example provided in gh-1738 --- dpctl/tests/test_usm_ndarray_unique.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/dpctl/tests/test_usm_ndarray_unique.py b/dpctl/tests/test_usm_ndarray_unique.py index f3504ee032..fcd55fdfc1 100644 --- a/dpctl/tests/test_usm_ndarray_unique.py +++ b/dpctl/tests/test_usm_ndarray_unique.py @@ -321,3 +321,25 @@ def test_set_functions_compute_follows_data(): assert ind.sycl_queue == q assert inv_ind.sycl_queue == q assert uc.sycl_queue == q + + +def test_gh_1738(): + get_queue_or_skip() + + ones = dpt.ones(10, dtype="i8") + iota = dpt.arange(10, dtype="i8") + + assert ones.device == iota.device + + dpt_info = dpt.__array_namespace_info__() + ind_dt = dpt_info.default_dtypes(device=ones.device)["indexing"] + + dt = dpt.unique_inverse(ones).inverse_indices.dtype + assert dt == ind_dt + dt = dpt.unique_all(ones).inverse_indices.dtype + assert dt == ind_dt + + dt = dpt.unique_inverse(iota).inverse_indices.dtype + assert dt == ind_dt + dt = dpt.unique_all(iota).inverse_indices.dtype + assert dt == ind_dt