diff --git a/pytensor/tensor/random/utils.py b/pytensor/tensor/random/utils.py index 9ddedb34b1..abbe6de702 100644 --- a/pytensor/tensor/random/utils.py +++ b/pytensor/tensor/random/utils.py @@ -129,7 +129,8 @@ def explicit_expand_dims( """Introduce explicit expand_dims in RV parameters that are implicitly broadcasted together and/or by size.""" batch_dims = [ - param.type.ndim - ndim_param for param, ndim_param in zip(params, ndim_params) + param.type.ndim - ndim_param + for param, ndim_param in zip(params, ndim_params, strict=True) ] if size_length is not None: diff --git a/tests/tensor/random/test_op.py b/tests/tensor/random/test_op.py index 2160cb83fe..36f111ecdb 100644 --- a/tests/tensor/random/test_op.py +++ b/tests/tensor/random/test_op.py @@ -74,9 +74,9 @@ def test_RandomVariable_basics(strict_test_value_flags): # `dtype` is respected rv = RandomVariable("normal", signature="(),()->()", dtype="int32") with config.change_flags(compute_test_value="off"): - rv_out = rv() + rv_out = rv(0, 0) assert rv_out.dtype == "int32" - rv_out = rv(dtype="int64") + rv_out = rv(0, 0, dtype="int64") assert rv_out.dtype == "int64" with pytest.raises( @@ -85,6 +85,10 @@ def test_RandomVariable_basics(strict_test_value_flags): ): assert rv(dtype="float32").dtype == "float32" + # If we pass fewer arguments (and there are no defaults), an error is raised + with pytest.raises(ValueError): + rv(0) + def test_RandomVariable_bcast(strict_test_value_flags): rv = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True)