From 4eb945c9aaf1724cdf729a48f77aa68a95b20471 Mon Sep 17 00:00:00 2001 From: Pamphile Roy Date: Tue, 2 May 2023 17:23:44 +0200 Subject: [PATCH 1/2] ENH: disallow passing array-like to array_namespace --- array_api_compat/common/_helpers.py | 4 +--- tests/test_array_namespace.py | 7 +------ 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index e6adc948..ee742b16 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -71,9 +71,7 @@ def your_function(x, y): """ namespaces = set() for x in xs: - if isinstance(x, (tuple, list)): - namespaces.add(array_namespace(*x, _use_compat=_use_compat)) - elif hasattr(x, '__array_namespace__'): + if hasattr(x, '__array_namespace__'): namespaces.add(x.__array_namespace__(api_version=api_version)) elif _is_numpy_array(x): _check_api_version(api_version) diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 806b1192..87d56b00 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -19,15 +19,10 @@ def test_array_namespace(library, api_version): else: assert namespace == getattr(array_api_compat, library) -def test_array_namespace_multiple(): - import numpy as np - - x = np.asarray([1, 2]) - assert array_namespace(x, x) == array_namespace((x, x)) == \ - array_namespace((x, x), x) == array_api_compat.numpy def test_array_namespace_errors(): pytest.raises(TypeError, lambda: array_namespace([1])) + pytest.raises(TypeError, lambda: array_namespace([1, 2])) pytest.raises(TypeError, lambda: array_namespace()) import numpy as np From 9507a42fd1eee6b7fbf88a75e5fb549b2d32dbf9 Mon Sep 17 00:00:00 2001 From: Pamphile Roy Date: Thu, 4 May 2023 12:36:21 +0200 Subject: [PATCH 2/2] TST: use previous array test --- tests/test_array_namespace.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 87d56b00..1675377d 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -22,18 +22,22 @@ def test_array_namespace(library, api_version): def test_array_namespace_errors(): pytest.raises(TypeError, lambda: array_namespace([1])) - pytest.raises(TypeError, lambda: array_namespace([1, 2])) pytest.raises(TypeError, lambda: array_namespace()) import numpy as np - import torch x = np.asarray([1, 2]) + + pytest.raises(TypeError, lambda: array_namespace((x, x))) + pytest.raises(TypeError, lambda: array_namespace(x, (x, x))) + + import torch y = torch.asarray([1, 2]) pytest.raises(TypeError, lambda: array_namespace(x, y)) pytest.raises(ValueError, lambda: array_namespace(x, api_version='2022.12')) + def test_get_namespace(): # Backwards compatible wrapper assert array_api_compat.get_namespace is array_api_compat.array_namespace