diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 8b7ef18e..538d73ab 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -251,6 +251,8 @@ def mutually_broadcastable_shapes( def symmetric_matrices(draw, dtypes=xps.floating_dtypes(), finite=True): shape = draw(square_matrix_shapes) dtype = draw(dtypes) + if not isinstance(finite, bool): + finite = draw(finite) elements = {'allow_nan': False, 'allow_infinity': False} if finite else None a = draw(xps.arrays(dtype=dtype, shape=shape, elements=elements)) upper = xp.triu(a) diff --git a/array_api_tests/meta/test_hypothesis_helpers.py b/array_api_tests/meta/test_hypothesis_helpers.py index b3e5cf3d..6cce7d2a 100644 --- a/array_api_tests/meta/test_hypothesis_helpers.py +++ b/array_api_tests/meta/test_hypothesis_helpers.py @@ -129,11 +129,9 @@ def run(n, d, data): -@given(m=hh.symmetric_matrices(hh.shared_floating_dtypes, - finite=st.shared(st.booleans(), key='finite')), - dtype=hh.shared_floating_dtypes, - finite=st.shared(st.booleans(), key='finite')) -def test_symmetric_matrices(m, dtype, finite): +@given(finite=st.booleans(), dtype=xps.floating_dtypes(), data=st.data()) +def test_symmetric_matrices(finite, dtype, data): + m = data.draw(hh.symmetric_matrices(st.just(dtype), finite=finite)) assert m.dtype == dtype # TODO: This part of this test should be part of the .mT test ah.assert_exactly_equal(m, m.mT)