diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py
index 06b39df8e7..9abc480286 100644
--- a/pytensor/tensor/blas.py
+++ b/pytensor/tensor/blas.py
@@ -1795,6 +1795,10 @@ def c_header_dirs(self, **kwargs):
         return ldflags(libs=False, include_dir=True)
 
     def c_code(self, node, name, inp, out, sub):
+        # Can only compile if linked to blas libraries
+        if len(self.c_libraries()) <= 0:
+            raise NotImplementedError()
+
         _x, _y = inp
         (_z,) = out
         fail = sub["fail"]
diff --git a/tests/tensor/test_blas.py b/tests/tensor/test_blas.py
index 34c757dc25..743dc53cc6 100644
--- a/tests/tensor/test_blas.py
+++ b/tests/tensor/test_blas.py
@@ -23,6 +23,7 @@
 from pytensor.tensor import inplace
 from pytensor.tensor.basic import as_tensor_variable
 from pytensor.tensor.blas import (
+    BatchedDot,
     Dot22,
     Dot22Scalar,
     Gemm,
@@ -2700,6 +2701,30 @@ def check_first_dim(inverted):
         check_first_dim(inverted)
 
 
+def test_batched_dot_blas_flags():
+    """Test that BatchedDot works regardless of presence of BLAS flags"""
+    mode = "FAST_RUN"
+    rng = np.random.default_rng(2708)
+
+    x = tensor("x", shape=(2, 5, 3))
+    y = tensor("y", shape=(2, 3, 1))
+    out = batched_dot(x, y)
+    assert isinstance(out.owner.op, BatchedDot)
+    x_test = rng.normal(size=x.type.shape).astype(x.type.dtype)
+    y_test = rng.normal(size=y.type.shape).astype(y.type.dtype)
+
+    fn = function([x, y], out, mode=mode)
+    [batched_dot_thunk] = fn.vm.thunks
+    assert hasattr(batched_dot_thunk, "cthunk")
+    np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)
+
+    with config.change_flags(blas__ldflags=""):
+        fn = function([x, y], out, mode=mode)
+        [batched_dot_thunk] = fn.vm.thunks
+        assert not hasattr(batched_dot_thunk, "cthunk")
+        np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)
+
+
 def test_batched_tensordot():
     rng = np.random.default_rng(unittest_tools.fetch_seed())
     first = tensor4("first")
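
Note on the mechanism (not part of the patch): when an Op's c_code raises NotImplementedError, PyTensor's thunk-building machinery abandons the C thunk and falls back to the Op's Python perform method, which is why the test asserts the absence of a cthunk attribute under empty blas__ldflags rather than expecting a compilation failure. Below is a minimal sketch of the same opt-out pattern on a toy COp, assuming PyTensor's public Op API; the DoubleIt op and all of its names are hypothetical and exist only for illustration.

    import numpy as np

    import pytensor
    import pytensor.tensor as pt
    from pytensor.graph.basic import Apply
    from pytensor.link.c.op import COp


    class DoubleIt(COp):
        """Hypothetical COp with a Python implementation only."""

        __props__ = ()

        def make_node(self, x):
            x = pt.as_tensor_variable(x)
            return Apply(self, [x], [x.type()])

        def perform(self, node, inputs, output_storage):
            # Pure-Python implementation used when no C thunk can be built
            output_storage[0][0] = 2 * inputs[0]

        def c_code(self, node, name, inputs, outputs, sub):
            # Same pattern as the patched BatchedDot.c_code: signal that no
            # C implementation is available, so compilation falls back to
            # `perform` instead of erroring out
            raise NotImplementedError()


    x = pt.dvector("x")
    f = pytensor.function([x], DoubleIt()(x))
    print(f(np.array([1.0, 2.0])))  # [2. 4.]

Under this (assumed) fallback path, the compiled function behaves identically to a C-backed one from the caller's perspective, only slower; that is the property the new test pins down for BatchedDot.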