diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py
index 0c2a43b1f1..62a9d4c375 100644
--- a/pymc/distributions/transforms.py
+++ b/pymc/distributions/transforms.py
@@ -107,7 +107,8 @@ def backward(self, value, *inputs):
     def forward(self, value, *inputs):
         y = pt.zeros(value.shape)
         y = pt.set_subtensor(y[..., 0], value[..., 0])
-        y = pt.set_subtensor(y[..., 1:], pt.log(value[..., 1:] - value[..., :-1]))
+        log_value = pt.log(value[..., 1:] - value[..., :-1])
+        y = pt.set_subtensor(y[..., 1:], log_value)
         return y
 
     def log_jac_det(self, value, *inputs):
@@ -116,8 +117,9 @@ def log_jac_det(self, value, *inputs):
 
 class SumTo1(Transform):
     """
-    Transforms K - 1 dimensional simplex space (k values in [0,1] and that sum to 1) to a K - 1 vector of values in [0,1]
-    This Transformation operates on the last dimension of the input tensor.
+    Transforms K - 1 dimensional simplex space (k values in [0,1] and that
+    sum to 1) to a K - 1 vector of values in [0,1]. This Transformation
+    operates on the last dimension of the input tensor.
     """
 
     name = "sumto1"
@@ -140,7 +142,8 @@ def log_jac_det(self, value, *inputs):
 
 class CholeskyCovPacked(Transform):
     """
-    Transforms the diagonal elements of the LKJCholeskyCov distribution to be on the
+    Transforms the diagonal elements of
+    the LKJCholeskyCov distribution to be on the
     log scale
     """
 
@@ -157,10 +160,14 @@ def __init__(self, n):
         self.diag_idxs = pt.arange(1, n + 1).cumsum() - 1
 
     def backward(self, value, *inputs):
-        return pt.set_subtensor(value[..., self.diag_idxs], pt.exp(value[..., self.diag_idxs]))
+        diag_values = value[..., self.diag_idxs]
+        exp_values = pt.exp(diag_values)
+        return pt.set_subtensor(value[..., self.diag_idxs], exp_values)
 
     def forward(self, value, *inputs):
-        return pt.set_subtensor(value[..., self.diag_idxs], pt.log(value[..., self.diag_idxs]))
+        diag_values = value[..., self.diag_idxs]
+        log_values = pt.log(diag_values)
+        return pt.set_subtensor(value[..., self.diag_idxs], log_values)
 
     def log_jac_det(self, value, *inputs):
         return pt.sum(value[..., self.diag_idxs], axis=-1)
@@ -180,8 +187,9 @@ def log_jac_det(self, value, *inputs):
 
 
 class Interval(IntervalTransform):
-    """Wrapper around :class:`pymc.logprob.transforms.IntervalTransform` for use in the
-    ``transform`` argument of a random variable.
+    """
+    Wrapper around :class:`pymc.logprob.transforms.IntervalTransform` for use
+    in the ``transform`` argument of a random variable.
 
     Parameters
     ----------
@@ -192,15 +200,15 @@ class Interval(IntervalTransform):
         Upper bound of the interval transform. Must be a constant finite value.
         By default (``upper=None``), the interval is not bounded above.
     bounds_fn : callable, optional
-        Alternative to lower and upper. Must return a tuple of lower and upper bounds
-        as a symbolic function of the respective distribution inputs. If one of lower or
-        upper is ``None``, the interval is unbounded on that edge.
-
-        .. warning:: Expressions returned by `bounds_fn` should depend only on the
-            distribution inputs or other constants. Expressions that depend on nonlocal
-            variables, such as other distributions defined in the model context will
-            likely break sampling.
+        Alternative to lower and upper. Must return a tuple of lower and upper
+        bounds as a symbolic function of the respective distribution inputs. If
+        one of lower or upper is ``None``, the interval is unbounded on
+        that edge.
+
+        .. warning:: Expressions returned by `bounds_fn` should depend only on
+            the distribution inputs or other constants. Expressions that depend
+            on nonlocal variables, such as other distributions defined in the
+            model context will likely break sampling.
 
     Examples
     --------
@@ -220,10 +228,14 @@ def get_bounds(rng, size, mu, sigma):
             return 0, None
 
         with pm.Model():
-            interval = pm.distributions.transforms.Interval(bounds_fn=get_bounds)
+            interval = pm.distributions.transforms.Interval(
+                bounds_fn=get_bounds
+            )
+
            x = pm.Normal("x", transform=interval)
 
-    Create a lower-bounded interval transform that depends on a distribution parameter
+    Create a lower-bounded interval transform that depends on a
+    distribution parameter
 
     .. code-block:: python
 
@@ -267,10 +279,47 @@ class ZeroSumTransform(Transform):
     """
     Constrains any random samples to sum to zero along the user-provided
     ``zerosum_axes``.
+    This transform is useful when modeling distributions where the sum of certain dimensions
+    must be zero, such as in some types of constrained latent variable models or in certain
+    types of signal processing applications.
+
     Parameters
     ----------
-    zerosum_axes : list of ints
-        Must be a list of integers (positive or negative).
+    zerosum_axes : list of int
+        List of integers specifying the axes along which the random samples should sum to zero.
+        Positive integers indicate dimensions in the standard order, while negative integers
+        can be used to reference dimensions from the end of the shape.
+
+    Examples
+    --------
+    Suppose you want to ensure that the last dimension of a tensor sums to zero. You can use
+    `ZeroSumTransform` as follows:
+
+    .. code-block:: python
+
+        import pymc as pm
+
+        with pm.Model() as model:
+            # Create a 2D variable with the last axis constrained to sum to zero
+            x = pm.Normal("x", shape=(10, 5), transform=pm.distributions.transforms.ZeroSumTransform(zerosum_axes=[-1]))
+
+    Methods
+    -------
+    forward(value, *rv_inputs)
+        Transforms the input tensor to ensure that the specified axes sum to zero.
+
+    backward(value, *rv_inputs)
+        Computes the inverse transform to convert back to the original space where the sum was zero.
+
+    log_jac_det(value, *rv_inputs)
+        Returns the log Jacobian determinant of the transform. For this transform, it is zero.
+
+    Notes
+    -----
+    The `extend_axis` and `extend_axis_rev` methods are used internally to handle the transformation:
+    - `extend_axis`: Extends the axis by adding an additional element to ensure zero-sum constraint.
+    - `extend_axis_rev`: Reverses the extension operation applied by `extend_axis`.
+
     """
 
     name = "zerosum"
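
As a quick sanity check of the refactored `CholeskyCovPacked.forward`/`backward` pair above, the following sketch (not part of the patch; it assumes `pytensor` and `pymc` are installed and uses hypothetical test values) verifies that the two methods remain exact inverses on a packed Cholesky vector:

import numpy as np
import pytensor.tensor as pt

from pymc.distributions.transforms import CholeskyCovPacked

# For n=3 the packed Cholesky vector has 6 entries; the diagonal entries sit at
# indices [0, 2, 5], i.e. diag_idxs = pt.arange(1, n + 1).cumsum() - 1.
transform = CholeskyCovPacked(n=3)

packed = pt.dvector("packed")
roundtrip = transform.backward(transform.forward(packed))

# Hypothetical packed values; the diagonal entries must be positive for forward().
value = np.array([1.5, 0.3, 2.0, -0.1, 0.4, 0.7])
np.testing.assert_allclose(roundtrip.eval({packed: value}), value)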