Skip to content

Commit d596ff3

Browse files
Raise if Cholesky output has nans when on_error="raise"
1 parent 97c9379 commit d596ff3

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def solve_triangular(a, b):
292292
res = _solve_triangular(a, b, trans, lower, unit_diagonal)
293293
if check_finite:
294294
if np.any(np.bitwise_or(np.isinf(res), np.isnan(res))):
295-
raise ValueError(
295+
raise np.linalg.LinAlgError(
296296
"Non-numeric values (nan or inf) returned by solve_triangular"
297297
)
298298
return res
@@ -349,11 +349,17 @@ def impl(A, lower=0, overwrite_a=False, check_finite=True):
349349
def numba_funcify_Cholesky(op, node, **kwargs):
350350
lower = op.lower
351351
overwrite_a = False
352-
on_error = op.on_error
352+
check_finite = op.on_error == "raise"
353353

354354
@numba_basic.numba_njit(inline="always")
355355
def nb_cholesky(a):
356-
res = _cholesky(a, lower, overwrite_a, on_error)
356+
res = _cholesky(a, lower, overwrite_a, check_finite)
357+
if check_finite:
358+
if np.any(np.bitwise_or(np.isinf(res), np.isnan(res))):
359+
raise np.linalg.LinAlgError(
360+
"Non-numeric values (nan or inf) returned by cholesky"
361+
)
362+
357363
return res
358364

359365
return nb_cholesky

tests/link/numba/test_slinalg.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,8 @@ def test_solve_triangular_raises_on_nan_inf(value):
102102
b = np.full((5, 1), value)
103103

104104
with pytest.raises(
105-
ValueError, match=re.escape("Non-numeric values (nan or inf) returned ")
105+
np.linalg.LinAlgError,
106+
match=re.escape("Non-numeric values (nan or inf) returned "),
106107
):
107108
f(A_tri, b)
108109

@@ -136,7 +137,7 @@ def test_numba_Cholesky_raises_on_nan():
136137
g = pt.linalg.cholesky(x, on_error="raise")
137138
f = pytensor.function([x], g, mode="NUMBA")
138139

139-
with pytest.raises(ValueError, match=r"Non-numeric values"):
140+
with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"):
140141
f(test_value)
141142

142143

0 commit comments

Comments
 (0)