diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index 39513670..5c779302 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -121,7 +121,7 @@ def assert_dtype( >>> assert_dtype('sum', x, out.dtype, default_int) """ - in_dtypes = in_dtype if isinstance(in_dtype, Sequence) else [in_dtype] + in_dtypes = in_dtype if isinstance(in_dtype, Sequence) and not isinstance(in_dtype, str) else [in_dtype] f_in_dtypes = dh.fmt_types(tuple(in_dtypes)) f_out_dtype = dh.dtype_to_name[out_dtype] if expected is None: