From 539f1c614f100f19eccbd762c8661da4a7c2575f Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Wed, 24 Jan 2024 12:54:15 -0600 Subject: [PATCH] update_argsort_test --- tests/test_sort.py | 4 ++-- tests/third_party/cupy/sorting_tests/test_sort.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_sort.py b/tests/test_sort.py index 768870f16329..1899604a3042 100644 --- a/tests/test_sort.py +++ b/tests/test_sort.py @@ -94,8 +94,8 @@ def test_argsort_dtype(self, dtype): np_array = numpy.array(a, dtype=dtype) dp_array = dpnp.array(np_array) - result = dpnp.argsort(dp_array) - expected = numpy.argsort(np_array) + result = dpnp.argsort(dp_array, kind="stable") + expected = numpy.argsort(np_array, kind="stable") assert_dtype_allclose(result, expected) @pytest.mark.parametrize("dtype", get_complex_dtypes()) diff --git a/tests/third_party/cupy/sorting_tests/test_sort.py b/tests/third_party/cupy/sorting_tests/test_sort.py index bef8859ba7be..60a48b31e381 100644 --- a/tests/third_party/cupy/sorting_tests/test_sort.py +++ b/tests/third_party/cupy/sorting_tests/test_sort.py @@ -296,12 +296,12 @@ def test_F_order(self, xp): ) ) class TestArgsort(unittest.TestCase): - def argsort(self, a, axis=-1): + def argsort(self, a, axis=-1, kind=None): if self.external: xp = cupy.get_array_module(a) - return xp.argsort(a, axis=axis) + return xp.argsort(a, axis=axis, kind=kind) else: - return a.argsort(axis=axis) + return a.argsort(axis=axis, kind=kind) # Test base cases @@ -317,7 +317,7 @@ def test_argsort_zero_dim(self, xp, dtype): @testing.numpy_cupy_array_equal() def test_argsort_one_dim(self, xp, dtype): a = testing.shaped_random((10,), xp, dtype) - return self.argsort(a) + return self.argsort(a, axis=-1, kind="stable") @testing.for_all_dtypes() @testing.numpy_cupy_array_equal()