From 657e20859bb721cf57c137d3d922e58bef21f580 Mon Sep 17 00:00:00 2001 From: Siddharth Baleja Date: Mon, 29 Jul 2024 02:45:03 +0530 Subject: [PATCH 1/2] Add ZeroSumNormal distribution --- pymc/distributions/transforms.py | 108 ++++++++++++++++++++++++------- 1 file changed, 83 insertions(+), 25 deletions(-) diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index 0c2a43b1f1..3466436a98 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -53,15 +53,20 @@ def __getattr__(name): if name in ("univariate_ordered", "multivariate_ordered"): - warnings.warn(f"{name} has been deprecated, use ordered instead.", FutureWarning) + warnings.warn( + f"{name} has been deprecated, use ordered instead.", + FutureWarning) return ordered if name in ("univariate_sum_to_1", "multivariate_sum_to_1"): - warnings.warn(f"{name} has been deprecated, use sum_to_1 instead.", FutureWarning) + warnings.warn( + f"{name} has been deprecated, use sum_to_1 instead.", + FutureWarning) return sum_to_1 if name == "RVTransform": - warnings.warn("RVTransform has been renamed to Transform", FutureWarning) + warnings.warn( + "RVTransform has been renamed to Transform", FutureWarning) return Transform raise AttributeError(f"module {__name__} has no attribute {name}") @@ -96,7 +101,9 @@ class Ordered(Transform): def __init__(self, ndim_supp=None): if ndim_supp is not None: - warnings.warn("ndim_supp argument is deprecated and has no effect", FutureWarning) + warnings.warn( + "ndim_supp argument is deprecated and has no effect", + FutureWarning) def backward(self, value, *inputs): x = pt.zeros(value.shape) @@ -107,7 +114,8 @@ def backward(self, value, *inputs): def forward(self, value, *inputs): y = pt.zeros(value.shape) y = pt.set_subtensor(y[..., 0], value[..., 0]) - y = pt.set_subtensor(y[..., 1:], pt.log(value[..., 1:] - value[..., :-1])) + log_value = pt.log(value[..., 1:] - value[..., :-1]) + y = pt.set_subtensor(y[..., 1:], log_value) return y def log_jac_det(self, value, *inputs): @@ -116,15 +124,18 @@ def log_jac_det(self, value, *inputs): class SumTo1(Transform): """ - Transforms K - 1 dimensional simplex space (k values in [0,1] and that sum to 1) to a K - 1 vector of values in [0,1] - This Transformation operates on the last dimension of the input tensor. + Transforms K - 1 dimensional simplex space (k values in [0,1] and that + sum to 1) to a K - 1 vector of values in [0,1]. This Transformation + operates on the last dimension of the input tensor. """ name = "sumto1" def __init__(self, ndim_supp=None): if ndim_supp is not None: - warnings.warn("ndim_supp argument is deprecated and has no effect", FutureWarning) + warnings.warn( + "ndim_supp argument is deprecated and has no effect", + FutureWarning) def backward(self, value, *inputs): remaining = 1 - pt.sum(value[..., :], axis=-1, keepdims=True) @@ -140,7 +151,8 @@ def log_jac_det(self, value, *inputs): class CholeskyCovPacked(Transform): """ - Transforms the diagonal elements of the LKJCholeskyCov distribution to be on the + Transforms the diagonal elements of + the LKJCholeskyCov distribution to be on the log scale """ @@ -157,10 +169,14 @@ def __init__(self, n): self.diag_idxs = pt.arange(1, n + 1).cumsum() - 1 def backward(self, value, *inputs): - return pt.set_subtensor(value[..., self.diag_idxs], pt.exp(value[..., self.diag_idxs])) + diag_values = value[..., self.diag_idxs] + exp_values = pt.exp(diag_values) + return pt.set_subtensor(value[..., self.diag_idxs], exp_values) def forward(self, value, *inputs): - return pt.set_subtensor(value[..., self.diag_idxs], pt.log(value[..., self.diag_idxs])) + diag_values = value[..., self.diag_idxs] + log_values = pt.log(diag_values) + return pt.set_subtensor(value[..., self.diag_idxs], log_values) def log_jac_det(self, value, *inputs): return pt.sum(value[..., self.diag_idxs], axis=-1) @@ -180,8 +196,9 @@ def log_jac_det(self, value, *inputs): class Interval(IntervalTransform): - """Wrapper around :class:`pymc.logprob.transforms.IntervalTransform` for use in the - ``transform`` argument of a random variable. + """ + Wrapper around :class:`pymc.logprob.transforms.IntervalTransform` for use + in the ``transform`` argument of a random variable. Parameters ---------- @@ -192,15 +209,15 @@ class Interval(IntervalTransform): Upper bound of the interval transform. Must be a constant finite value. By default (``upper=None``), the interval is not bounded above. bounds_fn : callable, optional - Alternative to lower and upper. Must return a tuple of lower and upper bounds - as a symbolic function of the respective distribution inputs. If one of lower or - upper is ``None``, the interval is unbounded on that edge. - - .. warning:: Expressions returned by `bounds_fn` should depend only on the - distribution inputs or other constants. Expressions that depend on nonlocal - variables, such as other distributions defined in the model context will - likely break sampling. + Alternative to lower and upper. Must return a tuple of lower and upper + bounds as a symbolic function of the respective distribution inputs. If + one of lower or upper is ``None``,the interval is unbounded on + that edge. + .. warning:: Expressions returned by `bounds_fn` should depend only on + the distribution inputs or other constants. Expressions that depend + on nonlocal variables, such as other distributions defined in the + model context will likely break sampling. Examples -------- @@ -220,10 +237,14 @@ def get_bounds(rng, size, mu, sigma): return 0, None with pm.Model(): - interval = pm.distributions.transforms.Interval(bounds_fn=get_bounds) + interval = pm.distributions.transforms.Interval( + bounds_fn=get_bounds + ) + x = pm.Normal("x", transform=interval) - Create a lower-bounded interval transform that depends on a distribution parameter + Create a lower-bounded interval transform that depends on a + distribution parameter .. code-block:: python @@ -267,10 +288,47 @@ class ZeroSumTransform(Transform): """ Constrains any random samples to sum to zero along the user-provided ``zerosum_axes``. + This transform is useful when modeling distributions where the sum of certain dimensions + must be zero, such as in some types of constrained latent variable models or in certain + types of signal processing applications. + Parameters ---------- - zerosum_axes : list of ints - Must be a list of integers (positive or negative). + zerosum_axes : list of int + List of integers specifying the axes along which the random samples should sum to zero. + Positive integers indicate dimensions in the standard order, while negative integers + can be used to reference dimensions from the end of the shape. + + Examples + -------- + Suppose you want to ensure that the last dimension of a tensor sums to zero. You can use + `ZeroSumTransform` as follows: + + .. code-block:: python + + import pymc as pm + + with pm.Model() as model: + # Create a 2D variable with the last axis constrained to sum to zero + x = pm.Normal("x", shape=(10, 5), transform=pm.distributions.transforms.ZeroSumTransform(zerosum_axes=[-1])) + + Methods + ------- + forward(value, *rv_inputs) + Transforms the input tensor to ensure that the specified axes sum to zero. + + backward(value, *rv_inputs) + Computes the inverse transform to convert back to the original space where the sum was zero. + + log_jac_det(value, *rv_inputs) + Returns the log Jacobian determinant of the transform. For this transform, it is zero. + + Notes + ----- + The `extend_axis` and `extend_axis_rev` methods are used internally to handle the transformation: + - `extend_axis`: Extends the axis by adding an additional element to ensure zero-sum constraint. + - `extend_axis_rev`: Reverses the extension operation applied by `extend_axis`. + """ name = "zerosum" From d05980a5b3ca6966cdb382b691d3bfed055a1b9c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 29 Jul 2024 11:16:58 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pymc/distributions/transforms.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index 3466436a98..62a9d4c375 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -53,20 +53,15 @@ def __getattr__(name): if name in ("univariate_ordered", "multivariate_ordered"): - warnings.warn( - f"{name} has been deprecated, use ordered instead.", - FutureWarning) + warnings.warn(f"{name} has been deprecated, use ordered instead.", FutureWarning) return ordered if name in ("univariate_sum_to_1", "multivariate_sum_to_1"): - warnings.warn( - f"{name} has been deprecated, use sum_to_1 instead.", - FutureWarning) + warnings.warn(f"{name} has been deprecated, use sum_to_1 instead.", FutureWarning) return sum_to_1 if name == "RVTransform": - warnings.warn( - "RVTransform has been renamed to Transform", FutureWarning) + warnings.warn("RVTransform has been renamed to Transform", FutureWarning) return Transform raise AttributeError(f"module {__name__} has no attribute {name}") @@ -101,9 +96,7 @@ class Ordered(Transform): def __init__(self, ndim_supp=None): if ndim_supp is not None: - warnings.warn( - "ndim_supp argument is deprecated and has no effect", - FutureWarning) + warnings.warn("ndim_supp argument is deprecated and has no effect", FutureWarning) def backward(self, value, *inputs): x = pt.zeros(value.shape) @@ -133,9 +126,7 @@ class SumTo1(Transform): def __init__(self, ndim_supp=None): if ndim_supp is not None: - warnings.warn( - "ndim_supp argument is deprecated and has no effect", - FutureWarning) + warnings.warn("ndim_supp argument is deprecated and has no effect", FutureWarning) def backward(self, value, *inputs): remaining = 1 - pt.sum(value[..., :], axis=-1, keepdims=True)