diff --git a/dpctl/tests/test_usm_ndarray_sorting.py b/dpctl/tests/test_usm_ndarray_sorting.py index a9a2ef1456..fa73dcfdfa 100644 --- a/dpctl/tests/test_usm_ndarray_sorting.py +++ b/dpctl/tests/test_usm_ndarray_sorting.py @@ -275,9 +275,16 @@ def test_sort_complex_fp_nan(dtype): assert np.allclose(dpt.asnumpy(s), expected, equal_nan=True) + pairs = [] for i, j in itertools.permutations(range(inp.shape[0]), 2): - r1 = dpt.asnumpy(dpt.sort(inp[dpt.asarray([i, j])])) - r2 = np.sort(dpt.asnumpy(inp[dpt.asarray([i, j])])) + pairs.append([i, j]) + sub_arrs = inp[dpt.asarray(pairs)] + m1 = dpt.asnumpy(dpt.sort(sub_arrs, axis=1)) + m2 = np.sort(dpt.asnumpy(sub_arrs), axis=1) + for k in range(len(pairs)): + i, j = pairs[k] + r1 = m1[k] + r2 = m2[k] assert np.array_equal( r1.view(np.int64), r2.view(np.int64) ), f"Failed for {i} and {j}"