From b3a9ef427d40b0fad5879a8bdd347c931a3d40ed Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 11 Apr 2024 10:28:58 +0100 Subject: [PATCH] Fallback to `xp.bool_` --- array_api_tests/__init__.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/array_api_tests/__init__.py b/array_api_tests/__init__.py index 9af6796b..fd05697d 100644 --- a/array_api_tests/__init__.py +++ b/array_api_tests/__init__.py @@ -11,7 +11,7 @@ # You can comment the following out and instead import the specific array module -# you want to test, e.g. `import numpy.array_api as xp`. +# you want to test, e.g. `import array_api_strict as xp`. if "ARRAY_API_TESTS_MODULE" in os.environ: xp_name = os.environ["ARRAY_API_TESTS_MODULE"] _module, _sub = xp_name, None @@ -33,6 +33,17 @@ ) +# If xp.bool is not available, like in some versions of NumPy and CuPy, try +# patching in xp.bool_. +try: + xp.bool +except AttributeError as e: + if hasattr(xp, "bool_"): + xp.bool = xp.bool_ + else: + raise e + + # We monkey patch floats() to always disable subnormals as they are out-of-scope _floats = st.floats