From 438310cef14741294ad1f1b61edceedc8361b225 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 19 Mar 2025 15:32:18 +0000 Subject: [PATCH] BUG: `isclose` PyTorch Array API 2024.12 compliance --- src/array_api_extra/_delegation.py | 14 ++++++-------- tests/test_funcs.py | 1 - 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index b6e58688..bb11b7ee 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -6,6 +6,7 @@ from ._lib import Backend, _funcs from ._lib._utils._compat import array_namespace +from ._lib._utils._helpers import asarrays from ._lib._utils._typing import Array __all__ = ["isclose", "pad"] @@ -107,14 +108,11 @@ def isclose( """ xp = array_namespace(a, b) if xp is None else xp - if _delegate( - xp, - Backend.NUMPY, - Backend.CUPY, - Backend.DASK, - Backend.JAX, - Backend.TORCH, - ): + if _delegate(xp, Backend.NUMPY, Backend.CUPY, Backend.DASK, Backend.JAX): + return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + + if _delegate(xp, Backend.TORCH): + a, b = asarrays(a, b, xp=xp) # Array API 2024.12 support return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 23344f62..a6b3711b 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -689,7 +689,6 @@ def test_none_shape_bool(self, xp: ModuleType): xp_assert_equal(isclose(a, b), xp.asarray([True, False])) @pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp") - @pytest.mark.xfail_xp_backend(Backend.TORCH, reason="Array API 2024.12 support") def test_python_scalar(self, xp: ModuleType): a = xp.asarray([0.0, 0.1], dtype=xp.float32) xp_assert_equal(isclose(a, 0.0), xp.asarray([True, False]))