
Commit 23e418d

Implement default_transform and transform argument for distributions (#7207)
Co-authored-by: Ricardo Vieira <[email protected]>
1 parent ed62da1 commit 23e418d

9 files changed: +126, -38 lines changed

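This commit adds a `default_transform` keyword alongside the existing `transform` keyword on every distribution. A minimal usage sketch, with illustrative variable names (not from this commit), based on the signatures and tests in the diff below:

    import pymc as pm
    from pymc.distributions.transforms import Interval, log, ordered

    with pm.Model():
        # Disable the automatic transform; previously spelled transform=None,
        # which now emits a UserWarning (see pymc/model/core.py below).
        sigma = pm.HalfNormal("sigma", default_transform=None)

        # When both are given, the default transform is applied first and the
        # user transform is chained after it via ChainedTransform.
        x = pm.Normal("x", transform=Interval(0, 1), default_transform=log)

        # Keep LogNormal's default log transform and additionally impose an
        # ordering constraint on the vector.
        y = pm.LogNormal("y", mu=[0, 0], sigma=[1, 1], transform=ordered)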

pymc/distributions/distribution.py

Lines changed: 4 additions & 2 deletions
@@ -475,6 +475,7 @@ def __new__(
         observed=None,
         total_size=None,
         transform=UNSET,
+        default_transform=UNSET,
         **kwargs,
     ) -> TensorVariable:
         """Adds a tensor variable corresponding to a PyMC distribution to the current model.
@@ -556,10 +557,11 @@ def __new__(
         rv_out = model.register_rv(
             rv_out,
             name,
-            observed,
-            total_size,
+            observed=observed,
+            total_size=total_size,
             dims=dims,
             transform=transform,
+            default_transform=default_transform,
             initval=initval,
         )

pymc/model/core.py

Lines changed: 64 additions & 16 deletions
@@ -22,7 +22,6 @@
 from sys import modules
 from typing import (
     TYPE_CHECKING,
-    Any,
     Literal,
     Optional,
     TypeVar,
@@ -48,7 +47,7 @@
 
 from pymc.blocking import DictToArrayBijection, RaveledVars
 from pymc.data import GenTensorVariable, is_minibatch
-from pymc.distributions.transforms import _default_transform
+from pymc.distributions.transforms import ChainedTransform, _default_transform
 from pymc.exceptions import (
     BlockModelAccessError,
     ImputationWarning,
@@ -58,6 +57,7 @@
 )
 from pymc.initial_point import make_initial_point_fn
 from pymc.logprob.basic import transformed_conditional_logp
+from pymc.logprob.transforms import Transform
 from pymc.logprob.utils import ParameterValueError, replace_rvs_by_values
 from pymc.model_graph import model_to_graphviz
 from pymc.pytensorf import (
@@ -1214,7 +1214,16 @@ def set_data(
         shared_object.set_value(values)
 
     def register_rv(
-        self, rv_var, name, observed=None, total_size=None, dims=None, transform=UNSET, initval=None
+        self,
+        rv_var,
+        name,
+        *,
+        observed=None,
+        total_size=None,
+        dims=None,
+        default_transform=UNSET,
+        transform=UNSET,
+        initval=None,
     ):
         """Register an (un)observed random variable with the model.
 
@@ -1229,8 +1238,10 @@ def register_rv(
            upscales logp of variable with ``coef = total_size/var.shape[0]``
         dims : tuple
            Dimension names for the variable.
+        default_transform
+            A default transform for the random variable in log-likelihood space.
         transform
-            A transform for the random variable in log-likelihood space.
+            Additional transform which may be applied after default transform.
         initval
            The initial value of the random variable.
 
@@ -1255,7 +1266,7 @@ def register_rv(
             if total_size is not None:
                 raise ValueError("total_size can only be passed to observed RVs")
             self.free_RVs.append(rv_var)
-            self.create_value_var(rv_var, transform)
+            self.create_value_var(rv_var, transform=transform, default_transform=default_transform)
             self.add_named_variable(rv_var, dims)
             self.set_initval(rv_var, initval)
         else:
@@ -1278,7 +1289,9 @@ def register_rv(
 
             # `rv_var` is potentially changed by `make_obs_var`,
             # for example into a new graph for imputation of missing data.
-            rv_var = self.make_obs_var(rv_var, observed, dims, transform, total_size)
+            rv_var = self.make_obs_var(
+                rv_var, observed, dims, default_transform, transform, total_size
+            )
 
         return rv_var
 
@@ -1287,7 +1300,8 @@ def make_obs_var(
         rv_var: TensorVariable,
         data: np.ndarray,
         dims,
-        transform: Any | None,
+        default_transform: Transform | None,
+        transform: Transform | None,
         total_size: int | None,
     ) -> TensorVariable:
         """Create a `TensorVariable` for an observed random variable.
@@ -1301,8 +1315,10 @@ def make_obs_var(
            The observed data.
         dims : tuple
            Dimension names for the variable.
-        transform : int, optional
+        default_transform
            A transform for the random variable in log-likelihood space.
+        transform
+            Additional transform which may be applied after default transform.
 
         Returns
         -------
@@ -1339,12 +1355,19 @@ def make_obs_var(
 
             # Register ObservedRV corresponding to observed component
             observed_rv.name = f"{name}_observed"
-            self.create_value_var(observed_rv, transform=None, value_var=observed_data)
+            self.create_value_var(
+                observed_rv, transform=transform, default_transform=None, value_var=observed_data
+            )
             self.add_named_variable(observed_rv)
             self.observed_RVs.append(observed_rv)
 
             # Register FreeRV corresponding to unobserved components
-            self.register_rv(unobserved_rv, f"{name}_unobserved", transform=transform)
+            self.register_rv(
+                unobserved_rv,
+                f"{name}_unobserved",
+                transform=transform,
+                default_transform=default_transform,
+            )
 
             # Register Deterministic that combines observed and missing
             # Note: This can widely increase memory consumption during sampling for large datasets
@@ -1363,14 +1386,21 @@ def make_obs_var(
             rv_var.name = name
 
             rv_var.tag.observations = data
-            self.create_value_var(rv_var, transform=None, value_var=data)
+            self.create_value_var(
+                rv_var, transform=transform, default_transform=None, value_var=data
+            )
             self.add_named_variable(rv_var, dims)
             self.observed_RVs.append(rv_var)
 
         return rv_var
 
     def create_value_var(
-        self, rv_var: TensorVariable, transform: Any, value_var: Variable | None = None
+        self,
+        rv_var: TensorVariable,
+        *,
+        default_transform: Transform,
+        transform: Transform,
+        value_var: Variable | None = None,
     ) -> TensorVariable:
         """Create a ``TensorVariable`` that will be used as the random
         variable's "value" in log-likelihood graphs.
@@ -1385,7 +1415,11 @@ def create_value_var(
         ----------
         rv_var : TensorVariable
 
-        transform : Any
+        default_transform: Transform
+            A transform for the random variable in log-likelihood space.
+
+        transform: Transform
+            Additional transform which may be applied after default transform.
 
         value_var : Variable, optional
 
@@ -1396,11 +1430,25 @@ def create_value_var(
 
         # Make the value variable a transformed value variable,
         # if there's an applicable transform
-        if transform is UNSET:
+        if transform is None and default_transform is UNSET:
+            default_transform = None
+            warnings.warn(
+                "To disable default transform, please use default_transform=None"
+                " instead of transform=None. Setting transform to None will"
+                " not have any effect in future.",
+                UserWarning,
+            )
+
+        if default_transform is UNSET:
             if rv_var.owner is None:
-                transform = None
+                default_transform = None
             else:
-                transform = _default_transform(rv_var.owner.op, rv_var)
+                default_transform = _default_transform(rv_var.owner.op, rv_var)
+
+        if transform is UNSET:
+            transform = default_transform
+        elif transform is not None and default_transform is not None:
+            transform = ChainedTransform([default_transform, transform])
 
         if value_var is None:
             if transform is None:
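
The last hunk above changes what passing transform=None means: it now emits a deprecation-style warning and will stop disabling the default transform in the future. A small sketch of the user-visible behavior under this commit (variable names are illustrative):

    import warnings
    import pymc as pm

    with pm.Model():
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            # Old spelling: still disables the default transform here, but warns
            # that it will stop having an effect in the future.
            a = pm.HalfNormal("a", transform=None)
        assert any("default_transform=None" in str(w.message) for w in caught)

    with pm.Model():
        # New spelling: disables the default transform without a warning.
        b = pm.HalfNormal("b", default_transform=None)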

pymc/model/fgraph.py

Lines changed: 4 additions & 2 deletions
@@ -320,12 +320,14 @@ def first_non_model_var(var):
            var, value, *dims = model_var.owner.inputs
            transform = model_var.owner.op.transform
            model.free_RVs.append(var)
-            model.create_value_var(var, transform=transform, value_var=value)
+            model.create_value_var(
+                var, transform=transform, default_transform=None, value_var=value
+            )
            model.set_initval(var, initval=None)
        elif isinstance(model_var.owner.op, ModelObservedRV):
            var, value, *dims = model_var.owner.inputs
            model.observed_RVs.append(var)
-            model.create_value_var(var, transform=None, value_var=value)
+            model.create_value_var(var, transform=None, default_transform=None, value_var=value)
        elif isinstance(model_var.owner.op, ModelPotential):
            var, *dims = model_var.owner.inputs
            model.potentials.append(var)

tests/distributions/test_mixture.py

Lines changed: 3 additions & 3 deletions
@@ -1359,17 +1359,17 @@ def test_warning(self):
 
         with warnings.catch_warnings():
             warnings.simplefilter("error")
-            Mixture("mix4", w=[0.5, 0.5], comp_dists=comp_dists, transform=None)
+            Mixture("mix4", w=[0.5, 0.5], comp_dists=comp_dists, default_transform=None)
 
         with warnings.catch_warnings():
             warnings.simplefilter("error")
-            Mixture("mix5", w=[0.5, 0.5], comp_dists=comp_dists, observed=1)
+            Mixture("mix6", w=[0.5, 0.5], comp_dists=comp_dists, observed=1)
 
         # Case where the appropriate default transform is None
         comp_dists = [Normal.dist(), Normal.dist()]
         with warnings.catch_warnings():
             warnings.simplefilter("error")
-            Mixture("mix6", w=[0.5, 0.5], comp_dists=comp_dists)
+            Mixture("mix7", w=[0.5, 0.5], comp_dists=comp_dists)
 
 
 class TestZeroInflatedMixture:

tests/distributions/test_transform.py

Lines changed: 2 additions & 2 deletions
@@ -619,7 +619,7 @@ def test_transform_univariate_dist_logp_shape():
 
 def test_univariate_transform_multivariate_dist_raises():
     with pm.Model() as m:
-        pm.Dirichlet("x", [1, 1, 1], transform=tr.log)
+        pm.Dirichlet("x", [1, 1, 1], default_transform=tr.log)
 
     for jacobian_val in (True, False):
         with pytest.raises(
@@ -645,7 +645,7 @@ def log_jac_det(self, value, *inputs):
     buggy_transform = BuggyTransform()
 
     with pm.Model() as m:
-        pm.Uniform("x", shape=(4, 3), transform=buggy_transform)
+        pm.Uniform("x", shape=(4, 3), default_transform=buggy_transform)
 
     for jacobian_val in (True, False):
         with pytest.raises(

tests/logprob/test_utils.py

Lines changed: 4 additions & 4 deletions
@@ -218,11 +218,11 @@ def test_interdependent_transformed_rvs(self, reversed):
             transform = pm.distributions.transforms.Interval(
                 bounds_fn=lambda *inputs: (inputs[-2], inputs[-1])
             )
-            x = pm.Uniform("x", lower=0, upper=1, transform=transform)
+            x = pm.Uniform("x", lower=0, upper=1, default_transform=transform)
             # Operation between the variables provides a regression test for #7054
-            y = pm.Uniform("y", lower=0, upper=pt.exp(x), transform=transform)
-            z = pm.Uniform("z", lower=0, upper=y, transform=transform)
-            w = pm.Uniform("w", lower=0, upper=pt.square(z), transform=transform)
+            y = pm.Uniform("y", lower=0, upper=pt.exp(x), default_transform=transform)
+            z = pm.Uniform("z", lower=0, upper=y, default_transform=transform)
+            w = pm.Uniform("w", lower=0, upper=pt.square(z), default_transform=transform)
 
         rvs = [x, y, z, w]
         if reversed:

tests/model/test_core.py

Lines changed: 42 additions & 6 deletions
@@ -42,7 +42,14 @@
 from pymc.blocking import DictToArrayBijection, RaveledVars
 from pymc.distributions import Normal, transforms
 from pymc.distributions.distribution import PartialObservedRV
-from pymc.distributions.transforms import log, simplex
+from pymc.distributions.transforms import (
+    ChainedTransform,
+    Interval,
+    LogTransform,
+    log,
+    ordered,
+    simplex,
+)
 from pymc.exceptions import ImputationWarning, ShapeError, ShapeWarning
 from pymc.logprob.basic import transformed_conditional_logp
 from pymc.logprob.transforms import IntervalTransform
@@ -527,6 +534,35 @@ def test_model_var_maps():
     assert model.rvs_to_transforms[x] is None
 
 
+class TestTransformArgs:
+    def test_transform_warning(self):
+        with pm.Model():
+            with pytest.warns(
+                UserWarning,
+                match="To disable default transform,"
+                " please use default_transform=None"
+                " instead of transform=None. Setting transform to"
+                " None will not have any effect in future.",
+            ):
+                a = pm.Normal("a", transform=None)
+
+    def test_transform_order(self):
+        with pm.Model() as model:
+            x = pm.Normal("x", transform=Interval(0, 1), default_transform=log)
+        transform = model.rvs_to_transforms[x]
+        assert isinstance(transform, ChainedTransform)
+        assert isinstance(transform.transform_list[0], LogTransform)
+        assert isinstance(transform.transform_list[1], Interval)
+
+    def test_default_transform_is_applied(self):
+        with pm.Model() as model1:
+            x1 = pm.LogNormal("x1", [0, 0], [1, 1], transform=ordered, default_transform=None)
+        with pm.Model() as model2:
+            x2 = pm.LogNormal("x2", [0, 0], [1, 1], transform=ordered)
+        assert np.isinf(model1.compile_logp()({"x1_ordered__": (-1, -1)}))
+        assert np.isfinite(model2.compile_logp()({"x2_chain__": (-1, -1)}))
+
+
 def test_make_obs_var():
     """
     Check returned values for `data` given known inputs to `as_tensor()`.
@@ -549,26 +585,26 @@ def test_make_obs_var():
 
     # The function requires data and RV dimensionality to be compatible
     with pytest.raises(ShapeError, match="Dimensionality of data and RV don't match."):
-        fake_model.make_obs_var(fake_distribution, np.ones((3, 3, 1)), None, None, None)
+        fake_model.make_obs_var(fake_distribution, np.ones((3, 3, 1)), None, None, None, None)
 
     # Check function behavior using the various inputs
     # dense, sparse: Ensure that the missing values are appropriately set to None
     # masked: a deterministic variable is returned
 
-    dense_output = fake_model.make_obs_var(fake_distribution, dense_input, None, None, None)
+    dense_output = fake_model.make_obs_var(fake_distribution, dense_input, None, None, None, None)
     assert dense_output == fake_distribution
     assert isinstance(fake_model.rvs_to_values[dense_output], TensorConstant)
     del fake_model.named_vars[fake_distribution.name]
 
-    sparse_output = fake_model.make_obs_var(fake_distribution, sparse_input, None, None, None)
+    sparse_output = fake_model.make_obs_var(fake_distribution, sparse_input, None, None, None, None)
     assert sparse_output == fake_distribution
     assert sparse.basic._is_sparse_variable(fake_model.rvs_to_values[sparse_output])
     del fake_model.named_vars[fake_distribution.name]
 
     # Here the RandomVariable is split into observed/imputed and a Deterministic is returned
     with pytest.warns(ImputationWarning):
         masked_output = fake_model.make_obs_var(
-            fake_distribution, masked_array_input, None, None, None
+            fake_distribution, masked_array_input, None, None, None, None
         )
     assert masked_output != fake_distribution
     assert not isinstance(masked_output, RandomVariable)
@@ -581,7 +617,7 @@ def test_make_obs_var():
 
     # Test that setting total_size returns a MinibatchRandomVariable
     scaled_outputs = fake_model.make_obs_var(
-        fake_distribution, dense_input, None, None, total_size=100
+        fake_distribution, dense_input, None, None, None, total_size=100
     )
     assert scaled_outputs != fake_distribution
     assert isinstance(scaled_outputs.owner.op, MinibatchRandomVariable)
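
The TestTransformArgs cases above can also be checked interactively; a quick inspection sketch, assuming (as the x2_chain__ key above implies) that a chained transform names the value variable with a "chain" suffix:

    import pymc as pm
    from pymc.distributions.transforms import Interval, log

    with pm.Model() as model:
        x = pm.Normal("x", transform=Interval(0, 1), default_transform=log)

    print(type(model.rvs_to_transforms[x]).__name__)  # ChainedTransform
    print(model.value_vars[0].name)                   # transformed value variable, e.g. "x_chain__"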

tests/model/transform/test_conditioning.py

Lines changed: 2 additions & 2 deletions
@@ -286,8 +286,8 @@ def test_change_value_transforms_error():
 
 def test_remove_value_transforms():
     with pm.Model() as base_m:
-        p = pm.Uniform("p", transform=logodds)
-        q = pm.Uniform("q", transform=logodds)
+        p = pm.Uniform("p", transform=logodds, default_transform=None)
+        q = pm.Uniform("q", transform=logodds, default_transform=None)
 
     new_m = remove_value_transforms(base_m)
     new_p = new_m["p"]

tests/sampling/test_mcmc.py

Lines changed: 1 addition & 1 deletion
@@ -303,7 +303,7 @@ def test_transform_with_rv_dependency(self, symbolic_rv):
            transform = pm.distributions.transforms.Interval(
                bounds_fn=lambda *inputs: (inputs[-2], inputs[-1])
            )
-            y = pm.Uniform("y", lower=0, upper=x, transform=transform)
+            y = pm.Uniform("y", lower=0, upper=x, transform=transform, default_transform=None)
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
                trace = pm.sample(tune=10, draws=50, return_inferencedata=False, random_seed=336)
