diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index b6e58688..bb11b7ee 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -6,6 +6,7 @@ from ._lib import Backend, _funcs from ._lib._utils._compat import array_namespace +from ._lib._utils._helpers import asarrays from ._lib._utils._typing import Array __all__ = ["isclose", "pad"] @@ -107,14 +108,11 @@ def isclose( """ xp = array_namespace(a, b) if xp is None else xp - if _delegate( - xp, - Backend.NUMPY, - Backend.CUPY, - Backend.DASK, - Backend.JAX, - Backend.TORCH, - ): + if _delegate(xp, Backend.NUMPY, Backend.CUPY, Backend.DASK, Backend.JAX): + return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + + if _delegate(xp, Backend.TORCH): + a, b = asarrays(a, b, xp=xp) # Array API 2024.12 support return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 23344f62..a6b3711b 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -689,7 +689,6 @@ def test_none_shape_bool(self, xp: ModuleType): xp_assert_equal(isclose(a, b), xp.asarray([True, False])) @pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp") - @pytest.mark.xfail_xp_backend(Backend.TORCH, reason="Array API 2024.12 support") def test_python_scalar(self, xp: ModuleType): a = xp.asarray([0.0, 0.1], dtype=xp.float32) xp_assert_equal(isclose(a, 0.0), xp.asarray([True, False]))