Skip to content

Commit 26b8e99

Browse files
Fix test_matmul_simple to avoid use of out-of-bounds Python scalars
Change the test so that input matrices that get multiplied only have blocks of ones no larger than the max integer for the type, rest is populated with zeros. This change applies to integral types only.
1 parent 7ca4395 commit 26b8e99

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

dpctl/tests/test_usm_ndarray_linalg.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,20 @@ def test_matmul_simple(dtype):
8989
skip_if_dtype_not_supported(dtype, q)
9090

9191
n, m = 235, 17
92-
m1 = dpt.ones((m, n), dtype=dtype)
93-
m2 = dpt.ones((n, m), dtype=dtype)
92+
m1 = dpt.zeros((m, n), dtype=dtype)
93+
m2 = dpt.zeros((n, m), dtype=dtype)
94+
95+
dt = m1.dtype
96+
if dt.kind in "ui":
97+
n1 = min(n, dpt.iinfo(dt).max)
98+
else:
99+
n1 = n
100+
m1[:, :n1] = dpt.ones((m, n1), dtype=dt)
101+
m2[:n1, :] = dpt.ones((n1, m), dtype=dt)
94102

95103
for k in [1, 2, 3, 4, 7, 8, 9, 15, 16, 17]:
96104
r = dpt.matmul(m1[:k, :], m2[:, :k])
97-
assert dpt.all(r == dpt.full((k, k), n, dtype=dtype))
105+
assert dpt.all(r == dpt.full((k, k), fill_value=n1, dtype=dt))
98106

99107

100108
@pytest.mark.parametrize("dtype", _numeric_types)

0 commit comments

Comments
 (0)