From 28e5bfe740cc403fb58097dc5a66796ff3474d75 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sat, 24 Aug 2024 11:18:42 +0200 Subject: [PATCH 1/3] Make lower and upper optional argument in Censored --- pymc/distributions/censored.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/distributions/censored.py b/pymc/distributions/censored.py index 0d33f06b39..ed11c633a5 100644 --- a/pymc/distributions/censored.py +++ b/pymc/distributions/censored.py @@ -112,7 +112,7 @@ class Censored(Distribution): rv_op = CensoredRV.rv_op @classmethod - def dist(cls, dist, lower, upper, **kwargs): + def dist(cls, dist, lower=-np.inf, upper=np.inf, **kwargs): if not isinstance(dist, TensorVariable) or not isinstance( dist.owner.op, RandomVariable | SymbolicRandomVariable ): From 7b3539074eade44c99db2636e32a2a857c566a65 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sat, 24 Aug 2024 11:19:01 +0200 Subject: [PATCH 2/3] Remove None as default arguments in ExGaussian --- pymc/distributions/continuous.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 03831ff63a..0af7193b32 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -2941,8 +2941,8 @@ class ExGaussian(Continuous): rv_op = ExGaussianRV.rv_op @classmethod - def dist(cls, mu=0.0, sigma=None, nu=None, *args, **kwargs): - return super().dist([mu, sigma, nu], *args, **kwargs) + def dist(cls, mu=0.0, sigma=1.0, *, nu, **kwargs): + return super().dist([mu, sigma, nu], **kwargs) def support_point(rv, size, mu, sigma, nu): mu, nu, _ = pt.broadcast_arrays(mu, nu, sigma) From ba61d12a19b0cd4d81a9f46e123c908895c43f46 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sat, 24 Aug 2024 11:51:19 +0200 Subject: [PATCH 3/3] Allow truncation of self-contained SymbolicRandomVariables --- pymc/distributions/truncated.py | 20 +++++++++++++------- tests/distributions/test_truncated.py | 21 +++++++++++++++++++-- 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/pymc/distributions/truncated.py b/pymc/distributions/truncated.py index 80dbcbf554..f0200b7368 100644 --- a/pymc/distributions/truncated.py +++ b/pymc/distributions/truncated.py @@ -28,7 +28,6 @@ from pytensor.tensor.random.type import RandomType from pymc.distributions.continuous import TruncatedNormal, bounded_cont_transform -from pymc.distributions.custom import CustomSymbolicDistRV from pymc.distributions.dist_math import check_parameters from pymc.distributions.distribution import ( Distribution, @@ -302,17 +301,24 @@ class Truncated(Distribution): def dist(cls, dist, lower=None, upper=None, max_n_steps: int = 10_000, **kwargs): if not ( isinstance(dist, TensorVariable) - and isinstance(dist.owner.op, RandomVariable | CustomSymbolicDistRV) + and dist.owner is not None + and isinstance(dist.owner.op, RandomVariable | SymbolicRandomVariable) ): - if isinstance(dist.owner.op, SymbolicRandomVariable): - raise NotImplementedError( - f"Truncation not implemented for SymbolicRandomVariable {dist.owner.op}.\n" - f"You can try wrapping the distribution inside a CustomDist instead." - ) raise ValueError( f"Truncation dist must be a distribution created via the `.dist()` API, got {type(dist)}" ) + if ( + isinstance(dist.owner.op, SymbolicRandomVariable) + and "[size]" not in dist.owner.op.extended_signature + ): + # Truncation needs to wrap the underlying dist, but not all SymbolicRandomVariables encapsulate the whole + # random graph and as such we don't know where the actual inputs begin. This happens mostly for + # distribution factories like `Censored` and `Mixture` which would have a very complex signature if they + # encapsulated the random components instead of taking them as inputs like they do now. + # SymbolicRandomVariables that encapsulate the whole random graph can be identified for having a size parameter. + raise NotImplementedError(f"Truncation not implemented for {dist.owner.op}") + if dist.owner.op.ndim_supp > 0: raise NotImplementedError("Truncation not implemented for multivariate distributions") diff --git a/tests/distributions/test_truncated.py b/tests/distributions/test_truncated.py index cf4824df74..5ef28791d9 100644 --- a/tests/distributions/test_truncated.py +++ b/tests/distributions/test_truncated.py @@ -21,7 +21,7 @@ from pytensor.tensor.random.basic import GeometricRV, NormalRV from pytensor.tensor.random.type import RandomType -from pymc import Model, draw, find_MAP +from pymc import ExGaussian, Model, Normal, draw, find_MAP from pymc.distributions import ( Censored, ChiSquared, @@ -342,7 +342,7 @@ def test_truncation_exceptions(): # Truncation does not work with SymbolicRV inputs with pytest.raises( NotImplementedError, - match="Truncation not implemented for SymbolicRandomVariable CensoredRV", + match="Truncation not implemented for CensoredRV", ): Truncated.dist(Censored.dist(pt.random.normal(), lower=-1, upper=1), -1, 1) @@ -599,3 +599,20 @@ def dist(scale, size): rv_out = Truncated.dist(latent, upper=7) assert np.ptp(draw(rv_out, draws=100)) < 7 + + +@pytest.mark.parametrize( + "dist_fn", + [ + lambda: ExGaussian.dist(nu=3), + pytest.param( + lambda: Censored.dist(Normal.dist(), lower=1), + marks=pytest.mark.xfail(raises=NotImplementedError), + ), + ], +) +def test_truncated_symbolic_rv(dist_fn): + dist = dist_fn() + trunc_dist = Truncated.dist(dist, lower=1, upper=3) + assert 1 <= draw(trunc_dist) <= 3 + assert (logp(trunc_dist, 2.5) > logp(dist, 2.5)).eval()