diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 0aa40110d5..4f66d9333b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -68,6 +68,7 @@ jobs: - | tests/distributions/test_censored.py + tests/distributions/test_custom.py tests/distributions/test_simulator.py tests/sampling/test_deterministic.py tests/sampling/test_forward.py diff --git a/docs/source/api/distributions.rst b/docs/source/api/distributions.rst index ca3a09cbaa..3384b8157d 100644 --- a/docs/source/api/distributions.rst +++ b/docs/source/api/distributions.rst @@ -14,6 +14,7 @@ Distributions distributions/timeseries distributions/truncated distributions/censored + distributions/custom distributions/simulator distributions/transforms distributions/utilities diff --git a/docs/source/api/distributions/custom.rst b/docs/source/api/distributions/custom.rst new file mode 100644 index 0000000000..a2e95c909a --- /dev/null +++ b/docs/source/api/distributions/custom.rst @@ -0,0 +1,20 @@ +********** +CustomDist +********** + +.. + Manually follow the template in _templates/distribution.rst. + If at any point, multiple objects are listed here, + the pattern should instead be modified to that of the + other API files such as api/distributions/continuous.rst + +.. currentmodule:: pymc + +.. autoclass:: CustomDist + + .. rubric:: Methods + + .. autosummary:: + :toctree: classmethods + + CustomDist.dist diff --git a/docs/source/api/distributions/utilities.rst b/docs/source/api/distributions/utilities.rst index 39a2a46f97..61cfb1a314 100644 --- a/docs/source/api/distributions/utilities.rst +++ b/docs/source/api/distributions/utilities.rst @@ -7,7 +7,6 @@ Distribution utilities :toctree: generated/ Continuous - CustomDist Discrete Distribution SymbolicRandomVariable diff --git a/pymc/distributions/__init__.py b/pymc/distributions/__init__.py index 51f075de54..d5b23bfbaf 100644 --- a/pymc/distributions/__init__.py +++ b/pymc/distributions/__init__.py @@ -50,6 +50,7 @@ Wald, Weibull, ) +from pymc.distributions.custom import CustomDist, DensityDist from pymc.distributions.discrete import ( Bernoulli, BetaBinomial, @@ -66,8 +67,6 @@ ) from pymc.distributions.distribution import ( Continuous, - CustomDist, - DensityDist, DiracDelta, Discrete, Distribution, diff --git a/pymc/distributions/custom.py b/pymc/distributions/custom.py new file mode 100644 index 0000000000..3f8882bc16 --- /dev/null +++ b/pymc/distributions/custom.py @@ -0,0 +1,841 @@ +# Copyright 2024 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
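Since this patch only relocates code, a quick sanity sketch of the public import surface, assuming the re-exports shown in the `pymc/distributions/__init__.py` hunk above (the asserts are illustrative, not part of the patch):

```python
import pymc as pm
from pymc.distributions.custom import CustomDist, DensityDist

# The package-level names keep pointing at the relocated objects,
# so existing user code is unaffected by the module move.
assert pm.CustomDist is CustomDist
# DensityDist is a plain alias (``DensityDist = CustomDist`` at module end).
assert DensityDist is CustomDist
```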
+import functools
+import re
+import warnings
+
+from collections.abc import Callable, Sequence
+
+from pytensor import Variable, clone_replace
+from pytensor import tensor as pt
+from pytensor.graph.basic import io_toposort
+from pytensor.graph.features import ReplaceValidate
+from pytensor.graph.rewriting.basic import GraphRewriter
+from pytensor.scan.op import Scan
+from pytensor.tensor import TensorVariable, as_tensor_variable
+from pytensor.tensor.random.op import RandomVariable
+from pytensor.tensor.random.type import RandomGeneratorType, RandomType
+from pytensor.tensor.random.utils import normalize_size_param
+from pytensor.tensor.utils import _parse_gufunc_signature, safe_signature
+
+from pymc.distributions.distribution import (
+    Distribution,
+    SymbolicRandomVariable,
+    _support_point,
+    support_point,
+)
+from pymc.distributions.shape_utils import _change_dist_size, rv_size_is_none
+from pymc.exceptions import BlockModelAccessError
+from pymc.logprob.abstract import _logcdf, _logprob
+from pymc.model.core import new_or_existing_block_model_access
+from pymc.pytensorf import collect_default_updates
+
+
+def default_not_implemented(rv_name, method_name):
+    message = (
+        f"Attempted to run {method_name} on the CustomDist '{rv_name}', "
+        f"but this method had not been provided when the distribution was "
+        f"constructed. Please re-build your model and provide a callable "
+        f"to '{rv_name}'s {method_name} keyword argument.\n"
+    )
+
+    def func(*args, **kwargs):
+        raise NotImplementedError(message)
+
+    return func
+
+
+def default_support_point(rv, size, *rv_inputs, rv_name=None, has_fallback=False):
+    if None not in rv.type.shape:
+        return pt.zeros(rv.type.shape)
+    elif rv.owner.op.ndim_supp == 0 and not rv_size_is_none(size):
+        return pt.zeros(size)
+    elif has_fallback:
+        return pt.zeros_like(rv)
+    else:
+        raise TypeError(
+            "Cannot safely infer the size of a multivariate random variable's support_point. "
+            f"Please provide a support_point function when instantiating the {rv_name} "
+            "random variable."
+        )
+
+
+class CustomDistRV(RandomVariable):
+    """
+    Base class for CustomDistRV
+
+    This should be subclassed when defining CustomDist objects.
+    """
+
+    name = "CustomDistRV"
+    _print_name = ("CustomDist", "\\operatorname{CustomDist}")
+
+    @classmethod
+    def rng_fn(cls, rng, *args):
+        args = list(args)
+        size = args.pop(-1)
+        return cls._random_fn(*args, rng=rng, size=size)
+
+
+class _CustomDist(Distribution):
+    """A distribution that returns a subclass of CustomDistRV"""
+
+    rv_type = CustomDistRV
+
+    @classmethod
+    def dist(
+        cls,
+        *dist_params,
+        logp: Callable | None = None,
+        logcdf: Callable | None = None,
+        random: Callable | None = None,
+        support_point: Callable | None = None,
+        ndim_supp: int | None = None,
+        ndims_params: Sequence[int] | None = None,
+        signature: str | None = None,
+        dtype: str = "floatX",
+        class_name: str = "CustomDist",
+        **kwargs,
+    ):
+        if ndim_supp is None or ndims_params is None:
+            if signature is None:
+                ndim_supp = 0
+                ndims_params = [0] * len(dist_params)
+            else:
+                inputs, outputs = _parse_gufunc_signature(signature)
+                ndim_supp = max(len(out) for out in outputs)
+                ndims_params = [len(inp) for inp in inputs]
+
+        if ndim_supp > 0:
+            raise NotImplementedError(
+                "CustomDist with ndim_supp > 0 and without a `dist` function is not supported."
+ ) + + dist_params = [as_tensor_variable(param) for param in dist_params] + + if logp is None: + logp = default_not_implemented(class_name, "logp") + + if logcdf is None: + logcdf = default_not_implemented(class_name, "logcdf") + + if support_point is None: + support_point = functools.partial( + default_support_point, + rv_name=class_name, + has_fallback=random is not None, + ) + + if random is None: + random = default_not_implemented(class_name, "random") + + return super().dist( + dist_params, + logp=logp, + logcdf=logcdf, + random=random, + support_point=support_point, + ndim_supp=ndim_supp, + ndims_params=ndims_params, + dtype=dtype, + class_name=class_name, + **kwargs, + ) + + @classmethod + def rv_op( + cls, + *dist_params, + logp: Callable | None, + logcdf: Callable | None, + random: Callable | None, + support_point: Callable | None, + ndim_supp: int, + ndims_params: Sequence[int], + dtype: str, + class_name: str, + **kwargs, + ): + rv_type = type( + class_name, + (CustomDistRV,), + dict( + name=class_name, + inplace=False, + ndim_supp=ndim_supp, + ndims_params=ndims_params, + dtype=dtype, + _print_name=(class_name, f"\\operatorname{{{class_name}}}"), + # Specific to CustomDist + _random_fn=random, + ), + ) + + # Dispatch custom methods + @_logprob.register(rv_type) + def custom_dist_logp(op, values, rng, size, *dist_params, **kwargs): + return logp(values[0], *dist_params) + + @_logcdf.register(rv_type) + def custom_dist_logcdf(op, value, rng, size, *dist_params, **kwargs): + return logcdf(value, *dist_params, **kwargs) + + @_support_point.register(rv_type) + def custom_dist_support_point(op, rv, rng, size, *dist_params): + return support_point(rv, size, *dist_params) + + rv_op = rv_type() + return rv_op(*dist_params, **kwargs) + + +class CustomSymbolicDistRV(SymbolicRandomVariable): + """ + Base class for CustomSymbolicDist + + This should be subclassed when defining custom CustomDist objects that have + symbolic random methods. 
+ """ + + default_output = 0 + + _print_name = ("CustomSymbolicDist", "\\operatorname{CustomSymbolicDist}") + + +class _CustomSymbolicDist(Distribution): + rv_type = CustomSymbolicDistRV + + @classmethod + def dist( + cls, + *dist_params, + dist: Callable, + logp: Callable | None = None, + logcdf: Callable | None = None, + support_point: Callable | None = None, + ndim_supp: int | None = None, + ndims_params: Sequence[int] | None = None, + signature: str | None = None, + dtype: str = "floatX", + class_name: str = "CustomDist", + **kwargs, + ): + dist_params = [as_tensor_variable(param) for param in dist_params] + + if logcdf is None: + logcdf = default_not_implemented(class_name, "logcdf") + + if signature is None: + if ndim_supp is None: + ndim_supp = 0 + if ndims_params is None: + ndims_params = [0] * len(dist_params) + signature = safe_signature( + core_inputs_ndim=ndims_params, + core_outputs_ndim=[ndim_supp], + ) + + return super().dist( + dist_params, + class_name=class_name, + logp=logp, + logcdf=logcdf, + dist=dist, + support_point=support_point, + signature=signature, + **kwargs, + ) + + @classmethod + def rv_op( + cls, + *dist_params, + dist: Callable, + logp: Callable | None, + logcdf: Callable | None, + support_point: Callable | None, + size=None, + signature: str, + class_name: str, + ): + size = normalize_size_param(size) + # If it's NoneConst, just use that as the dummy + dummy_size_param = size.type() if isinstance(size, TensorVariable) else size + dummy_dist_params = [dist_param.type() for dist_param in dist_params] + with new_or_existing_block_model_access( + error_msg_on_access="Model variables cannot be created in the dist function. Use the `.dist` API" + ): + dummy_rv = dist(*dummy_dist_params, dummy_size_param) + dummy_params = [dummy_size_param, *dummy_dist_params] + # RNGs are not passed as explicit inputs (because we usually don't know how many are needed) + # We retrieve them here. This will also raise if the user forgot to specify some update in a Scan Op + dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,)) + + rv_type = type( + class_name, + (CustomSymbolicDistRV,), + # If logp is not provided, we try to infer it from the dist graph + dict( + inline_logprob=logp is None, + _print_name=(class_name, f"\\operatorname{{{class_name}}}"), + ), + ) + + # Dispatch custom methods + if logp is not None: + + @_logprob.register(rv_type) + def custom_dist_logp(op, values, size, *inputs, **kwargs): + [value] = values + rv_params = inputs[: len(dist_params)] + return logp(value, *rv_params) + + if logcdf is not None: + + @_logcdf.register(rv_type) + def custom_dist_logcdf(op, value, size, *inputs, **kwargs): + rv_params = inputs[: len(dist_params)] + return logcdf(value, *rv_params) + + if support_point is not None: + + @_support_point.register(rv_type) + def custom_dist_support_point(op, rv, size, *params): + return support_point( + rv, + size, + *[ + p + for p in params + if not isinstance(p.type, RandomType | RandomGeneratorType) + ], + ) + + @_change_dist_size.register(rv_type) + def change_custom_dist_size(op, rv, new_size, expand): + node = rv.owner + + if expand: + shape = tuple(rv.shape) + old_size = shape[: len(shape) - node.op.ndim_supp] + new_size = tuple(new_size) + tuple(old_size) + new_size = pt.as_tensor(new_size, dtype="int64", ndim=1) + + old_size, *old_dist_params = node.inputs[: len(dist_params) + 1] + + # OpFromGraph has to be recreated if the size type changes! 
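+            # For instance, the op may have been created with a NoneConst size
+            # while `new_size` here is a concrete int64 vector, in which case the
+            # cached inner inputs no longer type-match. Fresh dummy inputs (and
+            # freshly collected RNG updates) are used to rebuild the op below.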
+ dummy_size_param = new_size.type() + dummy_dist_params = [dist_param.type() for dist_param in old_dist_params] + dummy_rv = dist(*dummy_dist_params, dummy_size_param) + dummy_params = [dummy_size_param, *dummy_dist_params] + updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,)) + rngs = updates_dict.keys() + rngs_updates = updates_dict.values() + new_rv_op = rv_type( + inputs=[*dummy_params, *rngs], + outputs=[dummy_rv, *rngs_updates], + signature=signature, + ) + new_rv = new_rv_op(new_size, *dist_params, *rngs) + + return new_rv + + if dummy_updates_dict: + rngs, rngs_updates = zip(*dummy_updates_dict.items()) + else: + rngs, rngs_updates = (), () + + inputs = [*dummy_params, *rngs] + outputs = [dummy_rv, *rngs_updates] + signature = cls._infer_final_signature( + signature, n_inputs=len(inputs), n_outputs=len(outputs), n_rngs=len(rngs) + ) + rv_op = rv_type( + inputs=inputs, + outputs=outputs, + signature=signature, + ) + return rv_op(size, *dist_params, *rngs) + + @staticmethod + def _infer_final_signature(signature: str, n_inputs, n_outputs, n_rngs) -> str: + """Add size and updates to user provided gufunc signature if they are missing.""" + + # Regex to split across outer commas + # Copied from https://stackoverflow.com/a/26634150 + outer_commas = re.compile(r",\s*(?![^()]*\))") + + input_sig, output_sig = signature.split("->") + # It's valid to have a signature without params inputs, as in a Flat RV + n_inputs_sig = len(outer_commas.split(input_sig)) if input_sig.strip() else 0 + n_outputs_sig = len(outer_commas.split(output_sig)) + + if n_inputs_sig == n_inputs and n_outputs_sig == n_outputs: + # User provided a signature with no missing parts + return signature + + size_sig = "[size]" + rngs_sig = ("[rng]",) * n_rngs + if n_inputs_sig == (n_inputs - n_rngs - 1): + # Assume size and rngs are missing + if input_sig.strip(): + input_sig = ",".join((size_sig, input_sig, *rngs_sig)) + else: + input_sig = ",".join((size_sig, *rngs_sig)) + if n_outputs_sig == (n_outputs - n_rngs): + # Assume updates are missing + output_sig = ",".join((output_sig, *rngs_sig)) + signature = "->".join((input_sig, output_sig)) + return signature + + +class SupportPointRewrite(GraphRewriter): + def rewrite_support_point_scan_node(self, node): + if not isinstance(node.op, Scan): + return + + node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs + op = node.op + + local_fgraph_topo = io_toposort(node_inputs, node_outputs) + + replace_with_support_point = [] + to_replace_set = set() + + for nd in local_fgraph_topo: + if nd not in to_replace_set and isinstance( + nd.op, RandomVariable | SymbolicRandomVariable + ): + replace_with_support_point.append(nd.out) + to_replace_set.add(nd) + givens = {} + if len(replace_with_support_point) > 0: + for item in replace_with_support_point: + givens[item] = support_point(item) + else: + return + op_outs = clone_replace(node_outputs, replace=givens) + + nwScan = Scan( + node_inputs, + op_outs, + op.info, + mode=op.mode, + profile=op.profile, + truncate_gradient=op.truncate_gradient, + name=op.name, + allow_gc=op.allow_gc, + ) + nw_node = nwScan(*(node.inputs), return_list=True)[0].owner + return nw_node + + def add_requirements(self, fgraph): + fgraph.attach_feature(ReplaceValidate()) + + def apply(self, fgraph): + for node in fgraph.toposort(): + if isinstance(node.op, RandomVariable | SymbolicRandomVariable): + fgraph.replace(node.out, support_point(node.out)) + elif isinstance(node.op, Scan): + new_node = 
self.rewrite_support_point_scan_node(node)
+                if new_node is not None:
+                    fgraph.replace_all(tuple(zip(node.outputs, new_node.outputs)))
+
+
+@_support_point.register(CustomSymbolicDistRV)
+def dist_support_point(op, rv, *args):
+    node = rv.owner
+    rv_out_idx = node.outputs.index(rv)
+
+    fgraph = op.fgraph.clone()
+    replace_support_point = SupportPointRewrite()
+    replace_support_point.rewrite(fgraph)
+    # Replace dummy inner inputs by outer inputs
+    fgraph.replace_all(tuple(zip(op.inner_inputs, args)), import_missing=True)
+    support_point = fgraph.outputs[rv_out_idx]
+    return support_point
+
+
+class CustomDist:
+    """A helper class to create custom distributions
+
+    This class can be used to wrap black-box random and logp methods for use in
+    forward and MCMC sampling.
+
+    A user can provide a `dist` function that returns a PyTensor graph built from
+    simpler PyMC distributions, which represents the distribution. This graph is
+    used to take random draws, and to infer the logp expression automatically
+    when not provided by the user.
+
+    Alternatively, a user can provide a `random` function that returns numerical
+    draws (e.g., via NumPy routines), and a `logp` function that must return a
+    PyTensor graph that represents the logp graph when evaluated. This is used for
+    MCMC sampling.
+
+    Additionally, a user can provide ``logcdf`` and ``support_point`` functions that must
+    return PyTensor graphs that compute those quantities. These may be used by other PyMC
+    routines.
+
+    Parameters
+    ----------
+    name : str
+        Name of the model variable.
+    dist_params : Tuple
+        A sequence of the distribution's parameters. These will be converted into
+        PyTensor tensor variables internally.
+    dist : Optional[Callable]
+        A callable that returns a PyTensor graph built from simpler PyMC distributions
+        which represents the distribution. This can be used by PyMC to take random draws
+        as well as to infer the logp of the distribution in some cases. In that case
+        it's not necessary to implement ``random`` or ``logp`` functions.
+
+        It must have the following signature: ``dist(*dist_params, size)``.
+        The symbolic tensor distribution parameters are passed as positional arguments in
+        the same order as they are supplied when the ``CustomDist`` is constructed.
+
+    random : Optional[Callable]
+        A callable that can be used to generate random draws from the distribution.
+
+        It must have the following signature: ``random(*dist_params, rng=None, size=None)``.
+        The numerical distribution parameters are passed as positional arguments in the
+        same order as they are supplied when the ``CustomDist`` is constructed.
+        The keyword arguments are ``rng``, which will provide the random variable's
+        associated :py:class:`~numpy.random.Generator`, and ``size``, which will represent
+        the desired size of the random draw. If ``None``, a ``NotImplementedError``
+        will be raised when trying to draw random samples from the distribution's
+        prior or posterior predictive.
+
+    logp : Optional[Callable]
+        A callable that calculates the log probability of some given ``value``
+        conditioned on certain distribution parameter values. It must have the
+        following signature: ``logp(value, *dist_params)``, where ``value`` is
+        a PyTensor tensor that represents the distribution value, and ``dist_params``
+        are the tensors that hold the values of the distribution parameters.
+        This function must return a PyTensor tensor.
+
+        When the `dist` function is specified, PyMC will try to automatically
+        infer the `logp` when this is not provided.
+
+        Otherwise, a ``NotImplementedError`` will be raised when trying to compute the
+        distribution's logp.
+    logcdf : Optional[Callable]
+        A callable that calculates the log cumulative probability of some given
+        ``value`` conditioned on certain distribution parameter values. It must have the
+        following signature: ``logcdf(value, *dist_params)``, where ``value`` is
+        a PyTensor tensor that represents the distribution value, and ``dist_params``
+        are the tensors that hold the values of the distribution parameters.
+        This function must return a PyTensor tensor. If ``None``, a ``NotImplementedError``
+        will be raised when trying to compute the distribution's logcdf.
+    support_point : Optional[Callable]
+        A callable that can be used to compute the finite logp point of the distribution.
+        It must have the following signature: ``support_point(rv, size, *rv_inputs)``.
+        The distribution's variable is passed as the first argument ``rv``. ``size``
+        is the random variable's size implied by the ``dims``, ``size`` and parameters
+        supplied to the distribution. Finally, ``rv_inputs`` is the sequence of the
+        distribution parameters, in the same order as they were supplied when the
+        CustomDist was created. If ``None``, a default ``support_point`` function will be
+        assigned that will always return 0, or an array of zeros.
+    ndim_supp : Optional[int]
+        The number of dimensions in the support of the distribution.
+        Inferred from ``signature``, if provided. Defaults to assuming
+        a scalar distribution, i.e. ``ndim_supp = 0``.
+    ndims_params : Optional[Sequence[int]]
+        The list of the number of dimensions in the support of each of the distribution's
+        parameters. Inferred from ``signature``, if provided. Defaults to assuming
+        all parameters are scalars, i.e. ``ndims_params=[0, ...]``.
+    signature : Optional[str]
+        A numpy vectorize-like signature that indicates the number and core dimensionality
+        of the input parameters and sample outputs of the CustomDist.
+        When specified, `ndim_supp` and `ndims_params` are not needed. See examples below.
+    dtype : str
+        The dtype of the distribution. All draws and observations passed into the
+        distribution will be cast to this dtype. This is not needed if a PyTensor
+        ``dist`` function is provided, which should already return the right dtype!
+    class_name : str
+        Name for the class which will wrap the CustomDist methods. When not specified,
+        it will be given the name of the model variable.
+    kwargs :
+        Extra keyword arguments are passed to the parent class's ``__new__`` method.
+
+
+    Examples
+    --------
+    Create a CustomDist that wraps a black-box logp function. This variable cannot be
+    used in prior or posterior predictive sampling because no random function was provided.
+
+    .. code-block:: python
+
+        import numpy as np
+        import pymc as pm
+        from pytensor.tensor import TensorVariable
+
+        def logp(value: TensorVariable, mu: TensorVariable) -> TensorVariable:
+            return -(value - mu) ** 2
+
+        with pm.Model():
+            mu = pm.Normal("mu", 0, 1)
+            pm.CustomDist(
+                "custom_dist",
+                mu,
+                logp=logp,
+                observed=np.random.randn(100),
+            )
+            idata = pm.sample(100)
+
+    Provide a random function that returns numerical draws. This allows one to use a
+    CustomDist in prior and posterior predictive sampling.
+    A gufunc signature was also provided, which may be used by other routines.
+
+    ..
code-block:: python + + from typing import Optional, Tuple + + import numpy as np + import pymc as pm + from pytensor.tensor import TensorVariable + + def logp(value: TensorVariable, mu: TensorVariable) -> TensorVariable: + return -(value - mu)**2 + + def random( + mu: np.ndarray | float, + rng: Optional[np.random.Generator] = None, + size : Optional[Tuple[int]]=None, + ) -> np.ndarray | float : + return rng.normal(loc=mu, scale=1, size=size) + + with pm.Model(): + mu = pm.Normal('mu', 0 , 1) + pm.CustomDist( + 'custom_dist', + mu, + logp=logp, + random=random, + signature="()->()", + observed=np.random.randn(100, 3), + size=(100, 3), + ) + prior = pm.sample_prior_predictive(10) + + Provide a dist function that creates a PyTensor graph built from other + PyMC distributions. PyMC can automatically infer that the logp of this + variable corresponds to a shifted Exponential distribution. + A gufunc signature was also provided, which may be used by other routines. + + .. code-block:: python + + import pymc as pm + from pytensor.tensor import TensorVariable + + def dist( + lam: TensorVariable, + shift: TensorVariable, + size: TensorVariable, + ) -> TensorVariable: + return pm.Exponential.dist(lam, size=size) + shift + + with pm.Model() as m: + lam = pm.HalfNormal("lam") + shift = -1 + pm.CustomDist( + "custom_dist", + lam, + shift, + dist=dist, + signature="(),()->()", + observed=[-1, -1, 0], + ) + + prior = pm.sample_prior_predictive() + posterior = pm.sample() + + Provide a dist function that creates a PyTensor graph built from other + PyMC distributions. PyMC can automatically infer that the logp of + this variable corresponds to a modified-PERT distribution. + + .. code-block:: python + + import pymc as pm + from pytensor.tensor import TensorVariable + + def pert( + low: TensorVariable, + peak: TensorVariable, + high: TensorVariable, + lmbda: TensorVariable, + size: TensorVariable, + ) -> TensorVariable: + range = (high - low) + s_alpha = 1 + lmbda * (peak - low) / range + s_beta = 1 + lmbda * (high - peak) / range + return pm.Beta.dist(s_alpha, s_beta, size=size) * range + low + + with pm.Model() as m: + low = pm.Normal("low", 0, 10) + peak = pm.Normal("peak", 50, 10) + high = pm.Normal("high", 100, 10) + lmbda = 4 + pm.CustomDist("pert", low, peak, high, lmbda, dist=pert, observed=[30, 35, 73]) + + m.point_logps() + + """ + + def __new__( + cls, + name, + *dist_params, + dist: Callable | None = None, + random: Callable | None = None, + logp: Callable | None = None, + logcdf: Callable | None = None, + support_point: Callable | None = None, + # TODO: Deprecate ndim_supp / ndims_params in favor of signature? + ndim_supp: int | None = None, + ndims_params: Sequence[int] | None = None, + signature: str | None = None, + dtype: str = "floatX", + **kwargs, + ): + if isinstance(kwargs.get("observed", None), dict): + raise TypeError( + "Since ``v4.0.0`` the ``observed`` parameter should be of type" + " ``pd.Series``, ``np.array``, or ``pm.Data``." + " Previous versions allowed passing distribution parameters as" + " a dictionary in ``observed``, in the current version these " + "parameters are positional arguments." + ) + dist_params = cls.parse_dist_params(dist_params) + cls.check_valid_dist_random(dist, random, dist_params) + moment = kwargs.pop("moment", None) + if moment is not None: + warnings.warn( + "`moment` argument is deprecated. 
Use `support_point` instead.", + FutureWarning, + ) + support_point = moment + if dist is not None: + kwargs.setdefault("class_name", f"CustomDist_{name}") + return _CustomSymbolicDist( + name, + *dist_params, + dist=dist, + logp=logp, + logcdf=logcdf, + support_point=support_point, + ndim_supp=ndim_supp, + ndims_params=ndims_params, + signature=signature, + **kwargs, + ) + else: + kwargs.setdefault("class_name", f"CustomDist_{name}") + return _CustomDist( + name, + *dist_params, + random=random, + logp=logp, + logcdf=logcdf, + support_point=support_point, + ndim_supp=ndim_supp, + ndims_params=ndims_params, + signature=signature, + dtype=dtype, + **kwargs, + ) + + @classmethod + def dist( + cls, + *dist_params, + dist: Callable | None = None, + random: Callable | None = None, + logp: Callable | None = None, + logcdf: Callable | None = None, + support_point: Callable | None = None, + ndim_supp: int | None = None, + ndims_params: Sequence[int] | None = None, + signature: str | None = None, + dtype: str = "floatX", + **kwargs, + ): + dist_params = cls.parse_dist_params(dist_params) + cls.check_valid_dist_random(dist, random, dist_params) + if dist is not None: + return _CustomSymbolicDist.dist( + *dist_params, + dist=dist, + logp=logp, + logcdf=logcdf, + support_point=support_point, + ndim_supp=ndim_supp, + ndims_params=ndims_params, + signature=signature, + **kwargs, + ) + else: + return _CustomDist.dist( + *dist_params, + random=random, + logp=logp, + logcdf=logcdf, + support_point=support_point, + ndim_supp=ndim_supp, + ndims_params=ndims_params, + signature=signature, + dtype=dtype, + **kwargs, + ) + + @classmethod + def parse_dist_params(cls, dist_params): + if len(dist_params) > 0 and callable(dist_params[0]): + raise TypeError( + "The DensityDist API has changed, you are using the old API " + "where logp was the first positional argument. In the current API, " + "the logp is a keyword argument, amongst other changes. Please refer " + "to the API documentation for more information on how to use the " + "new DensityDist API." + ) + return [as_tensor_variable(param) for param in dist_params] + + @classmethod + def check_valid_dist_random(cls, dist, random, dist_params): + if dist is not None and random is not None: + raise ValueError("Cannot provide both dist and random functions") + if random is not None and cls.is_symbolic_random(random, dist_params): + raise TypeError( + "API change: function passed to `random` argument should no longer return a PyTensor graph. " + "Pass such function to the `dist` argument instead." + ) + + @classmethod + def is_symbolic_random(self, random, dist_params): + if random is None: + return False + # Try calling random with symbolic inputs + try: + size = normalize_size_param(None) + with new_or_existing_block_model_access( + error_msg_on_access="Model variables cannot be created in the random function. Use the `.dist` API to create such variables." 
+ ): + out = random(*dist_params, size) + except BlockModelAccessError: + raise + except Exception: + # If it fails we assume it was not + return False + # Confirm the output is symbolic + return isinstance(out, Variable) + + +DensityDist = CustomDist diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index e5c2e68684..d71dda97c8 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -28,16 +28,12 @@ from pytensor import tensor as pt from pytensor.compile.builders import OpFromGraph from pytensor.graph import FunctionGraph, clone_replace, node_rewriter -from pytensor.graph.basic import Apply, Variable, io_toposort -from pytensor.graph.features import ReplaceValidate -from pytensor.graph.rewriting.basic import GraphRewriter, in2out +from pytensor.graph.basic import Apply, Variable +from pytensor.graph.rewriting.basic import in2out from pytensor.graph.utils import MetaType -from pytensor.scan.op import Scan from pytensor.tensor.basic import as_tensor_variable -from pytensor.tensor.blockwise import safe_signature from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.rewriting import local_subtensor_rv_lift -from pytensor.tensor.random.type import RandomGeneratorType, RandomType from pytensor.tensor.random.utils import normalize_size_param from pytensor.tensor.rewriting.shape import ShapeFeature from pytensor.tensor.utils import _parse_gufunc_signature @@ -54,14 +50,11 @@ rv_size_is_none, shape_from_dims, ) -from pymc.exceptions import BlockModelAccessError from pymc.logprob.abstract import MeasurableVariable, _icdf, _logcdf, _logprob from pymc.logprob.basic import logp from pymc.logprob.rewriting import logprob_rewrites_db -from pymc.model.core import new_or_existing_block_model_access from pymc.printing import str_for_dist from pymc.pytensorf import ( - collect_default_updates, collect_default_updates_inner_fgraph, constant_fold, convert_observed_data, @@ -71,8 +64,6 @@ from pymc.vartypes import continuous_types, string_types __all__ = [ - "CustomDist", - "DensityDist", "DiracDelta", "Distribution", "Continuous", @@ -89,59 +80,6 @@ PLATFORM = sys.platform -class FiniteLogpPointRewrite(GraphRewriter): - def rewrite_support_point_scan_node(self, node): - if not isinstance(node.op, Scan): - return - - node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs - op = node.op - - local_fgraph_topo = io_toposort(node_inputs, node_outputs) - - replace_with_support_point = [] - to_replace_set = set() - - for nd in local_fgraph_topo: - if nd not in to_replace_set and isinstance( - nd.op, RandomVariable | SymbolicRandomVariable - ): - replace_with_support_point.append(nd.out) - to_replace_set.add(nd) - givens = {} - if len(replace_with_support_point) > 0: - for item in replace_with_support_point: - givens[item] = support_point(item) - else: - return - op_outs = clone_replace(node_outputs, replace=givens) - - nwScan = Scan( - node_inputs, - op_outs, - op.info, - mode=op.mode, - profile=op.profile, - truncate_gradient=op.truncate_gradient, - name=op.name, - allow_gc=op.allow_gc, - ) - nw_node = nwScan(*(node.inputs), return_list=True)[0].owner - return nw_node - - def add_requirements(self, fgraph): - fgraph.attach_feature(ReplaceValidate()) - - def apply(self, fgraph): - for node in fgraph.toposort(): - if isinstance(node.op, RandomVariable | SymbolicRandomVariable): - fgraph.replace(node.out, support_point(node.out)) - elif isinstance(node.op, Scan): - new_node = 
self.rewrite_support_point_scan_node(node) - if new_node is not None: - fgraph.replace_all(tuple(zip(node.outputs, new_node.outputs))) - - class _Unpickling: pass @@ -755,752 +693,6 @@ class Continuous(Distribution): """Base class for continuous distributions""" -class CustomDistRV(RandomVariable): - """ - Base class for CustomDistRV - - This should be subclassed when defining CustomDist objects. - """ - - name = "CustomDistRV" - _print_name = ("CustomDist", "\\operatorname{CustomDist}") - - @classmethod - def rng_fn(cls, rng, *args): - args = list(args) - size = args.pop(-1) - return cls._random_fn(*args, rng=rng, size=size) - - -class _CustomDist(Distribution): - """A distribution that returns a subclass of CustomDistRV""" - - rv_type = CustomDistRV - - @classmethod - def dist( - cls, - *dist_params, - logp: Callable | None = None, - logcdf: Callable | None = None, - random: Callable | None = None, - support_point: Callable | None = None, - ndim_supp: int | None = None, - ndims_params: Sequence[int] | None = None, - signature: str | None = None, - dtype: str = "floatX", - class_name: str = "CustomDist", - **kwargs, - ): - if ndim_supp is None or ndims_params is None: - if signature is None: - ndim_supp = 0 - ndims_params = [0] * len(dist_params) - else: - inputs, outputs = _parse_gufunc_signature(signature) - ndim_supp = max(len(out) for out in outputs) - ndims_params = [len(inp) for inp in inputs] - - if ndim_supp > 0: - raise NotImplementedError( - "CustomDist with ndim_supp > 0 and without a `dist` function are not supported." - ) - - dist_params = [as_tensor_variable(param) for param in dist_params] - - if logp is None: - logp = default_not_implemented(class_name, "logp") - - if logcdf is None: - logcdf = default_not_implemented(class_name, "logcdf") - - if support_point is None: - support_point = functools.partial( - default_support_point, - rv_name=class_name, - has_fallback=random is not None, - ) - - if random is None: - random = default_not_implemented(class_name, "random") - - return super().dist( - dist_params, - logp=logp, - logcdf=logcdf, - random=random, - support_point=support_point, - ndim_supp=ndim_supp, - ndims_params=ndims_params, - dtype=dtype, - class_name=class_name, - **kwargs, - ) - - @classmethod - def rv_op( - cls, - *dist_params, - logp: Callable | None, - logcdf: Callable | None, - random: Callable | None, - support_point: Callable | None, - ndim_supp: int, - ndims_params: Sequence[int], - dtype: str, - class_name: str, - **kwargs, - ): - rv_type = type( - class_name, - (CustomDistRV,), - dict( - name=class_name, - inplace=False, - ndim_supp=ndim_supp, - ndims_params=ndims_params, - dtype=dtype, - _print_name=(class_name, f"\\operatorname{{{class_name}}}"), - # Specific to CustomDist - _random_fn=random, - ), - ) - - # Dispatch custom methods - @_logprob.register(rv_type) - def custom_dist_logp(op, values, rng, size, *dist_params, **kwargs): - return logp(values[0], *dist_params) - - @_logcdf.register(rv_type) - def custom_dist_logcdf(op, value, rng, size, *dist_params, **kwargs): - return logcdf(value, *dist_params, **kwargs) - - @_support_point.register(rv_type) - def custom_dist_support_point(op, rv, rng, size, *dist_params): - return support_point(rv, size, *dist_params) - - rv_op = rv_type() - return rv_op(*dist_params, **kwargs) - - -class CustomSymbolicDistRV(SymbolicRandomVariable): - """ - Base class for CustomSymbolicDist - - This should be subclassed when defining custom CustomDist objects that have - symbolic random methods. 
- """ - - default_output = 0 - - _print_name = ("CustomSymbolicDist", "\\operatorname{CustomSymbolicDist}") - - -@_support_point.register(CustomSymbolicDistRV) -def dist_support_point(op, rv, *args): - node = rv.owner - rv_out_idx = node.outputs.index(rv) - - fgraph = op.fgraph.clone() - replace_support_point = FiniteLogpPointRewrite() - replace_support_point.rewrite(fgraph) - # Replace dummy inner inputs by outer inputs - fgraph.replace_all(tuple(zip(op.inner_inputs, args)), import_missing=True) - support_point = fgraph.outputs[rv_out_idx] - return support_point - - -class _CustomSymbolicDist(Distribution): - rv_type = CustomSymbolicDistRV - - @classmethod - def dist( - cls, - *dist_params, - dist: Callable, - logp: Callable | None = None, - logcdf: Callable | None = None, - support_point: Callable | None = None, - ndim_supp: int | None = None, - ndims_params: Sequence[int] | None = None, - signature: str | None = None, - dtype: str = "floatX", - class_name: str = "CustomDist", - **kwargs, - ): - dist_params = [as_tensor_variable(param) for param in dist_params] - - if logcdf is None: - logcdf = default_not_implemented(class_name, "logcdf") - - if signature is None: - if ndim_supp is None: - ndim_supp = 0 - if ndims_params is None: - ndims_params = [0] * len(dist_params) - signature = safe_signature( - core_inputs_ndim=ndims_params, - core_outputs_ndim=[ndim_supp], - ) - - return super().dist( - dist_params, - class_name=class_name, - logp=logp, - logcdf=logcdf, - dist=dist, - support_point=support_point, - signature=signature, - **kwargs, - ) - - @classmethod - def rv_op( - cls, - *dist_params, - dist: Callable, - logp: Callable | None, - logcdf: Callable | None, - support_point: Callable | None, - size=None, - signature: str, - class_name: str, - ): - size = normalize_size_param(size) - # If it's NoneConst, just use that as the dummy - dummy_size_param = size.type() if isinstance(size, TensorVariable) else size - dummy_dist_params = [dist_param.type() for dist_param in dist_params] - with new_or_existing_block_model_access( - error_msg_on_access="Model variables cannot be created in the dist function. Use the `.dist` API" - ): - dummy_rv = dist(*dummy_dist_params, dummy_size_param) - dummy_params = [dummy_size_param, *dummy_dist_params] - # RNGs are not passed as explicit inputs (because we usually don't know how many are needed) - # We retrieve them here. 
This will also raise if the user forgot to specify some update in a Scan Op - dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,)) - - rv_type = type( - class_name, - (CustomSymbolicDistRV,), - # If logp is not provided, we try to infer it from the dist graph - dict( - inline_logprob=logp is None, - _print_name=(class_name, f"\\operatorname{{{class_name}}}"), - ), - ) - - # Dispatch custom methods - if logp is not None: - - @_logprob.register(rv_type) - def custom_dist_logp(op, values, size, *inputs, **kwargs): - [value] = values - rv_params = inputs[: len(dist_params)] - return logp(value, *rv_params) - - if logcdf is not None: - - @_logcdf.register(rv_type) - def custom_dist_logcdf(op, value, size, *inputs, **kwargs): - rv_params = inputs[: len(dist_params)] - return logcdf(value, *rv_params) - - if support_point is not None: - - @_support_point.register(rv_type) - def custom_dist_support_point(op, rv, size, *params): - return support_point( - rv, - size, - *[ - p - for p in params - if not isinstance(p.type, RandomType | RandomGeneratorType) - ], - ) - - @_change_dist_size.register(rv_type) - def change_custom_dist_size(op, rv, new_size, expand): - node = rv.owner - - if expand: - shape = tuple(rv.shape) - old_size = shape[: len(shape) - node.op.ndim_supp] - new_size = tuple(new_size) + tuple(old_size) - new_size = pt.as_tensor(new_size, dtype="int64", ndim=1) - - old_size, *old_dist_params = node.inputs[: len(dist_params) + 1] - - # OpFromGraph has to be recreated if the size type changes! - dummy_size_param = new_size.type() - dummy_dist_params = [dist_param.type() for dist_param in old_dist_params] - dummy_rv = dist(*dummy_dist_params, dummy_size_param) - dummy_params = [dummy_size_param, *dummy_dist_params] - updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,)) - rngs = updates_dict.keys() - rngs_updates = updates_dict.values() - new_rv_op = rv_type( - inputs=[*dummy_params, *rngs], - outputs=[dummy_rv, *rngs_updates], - signature=signature, - ) - new_rv = new_rv_op(new_size, *dist_params, *rngs) - - return new_rv - - if dummy_updates_dict: - rngs, rngs_updates = zip(*dummy_updates_dict.items()) - else: - rngs, rngs_updates = (), () - - inputs = [*dummy_params, *rngs] - outputs = [dummy_rv, *rngs_updates] - signature = cls._infer_final_signature( - signature, n_inputs=len(inputs), n_outputs=len(outputs), n_rngs=len(rngs) - ) - rv_op = rv_type( - inputs=inputs, - outputs=outputs, - signature=signature, - ) - return rv_op(size, *dist_params, *rngs) - - @staticmethod - def _infer_final_signature(signature: str, n_inputs, n_outputs, n_rngs) -> str: - """Add size and updates to user provided gufunc signature if they are missing.""" - - # Regex to split across outer commas - # Copied from https://stackoverflow.com/a/26634150 - outer_commas = re.compile(r",\s*(?![^()]*\))") - - input_sig, output_sig = signature.split("->") - # It's valid to have a signature without params inputs, as in a Flat RV - n_inputs_sig = len(outer_commas.split(input_sig)) if input_sig.strip() else 0 - n_outputs_sig = len(outer_commas.split(output_sig)) - - if n_inputs_sig == n_inputs and n_outputs_sig == n_outputs: - # User provided a signature with no missing parts - return signature - - size_sig = "[size]" - rngs_sig = ("[rng]",) * n_rngs - if n_inputs_sig == (n_inputs - n_rngs - 1): - # Assume size and rngs are missing - if input_sig.strip(): - input_sig = ",".join((size_sig, input_sig, *rngs_sig)) - else: - input_sig = ",".join((size_sig, *rngs_sig)) 
- if n_outputs_sig == (n_outputs - n_rngs): - # Assume updates are missing - output_sig = ",".join((output_sig, *rngs_sig)) - signature = "->".join((input_sig, output_sig)) - return signature - - -class CustomDist: - """A helper class to create custom distributions - - This class can be used to wrap black-box random and logp methods for use in - forward and mcmc sampling. - - A user can provide a `dist` function that returns a PyTensor graph built from - simpler PyMC distributions, which represents the distribution. This graph is - used to take random draws, and to infer the logp expression automatically - when not provided by the user. - - Alternatively, a user can provide a `random` function that returns numerical - draws (e.g., via NumPy routines), and a `logp` function that must return a - PyTensor graph that represents the logp graph when evaluated. This is used for - mcmc sampling. - - Additionally, a user can provide a `logcdf` and `support_point` functions that must return - PyTensor graphs that computes those quantities. These may be used by other PyMC - routines. - - Parameters - ---------- - name : str - dist_params : Tuple - A sequence of the distribution's parameter. These will be converted into - Pytensor tensor variables internally. - dist: Optional[Callable] - A callable that returns a PyTensor graph built from simpler PyMC distributions - which represents the distribution. This can be used by PyMC to take random draws - as well as to infer the logp of the distribution in some cases. In that case - it's not necessary to implement ``random`` or ``logp`` functions. - - It must have the following signature: ``dist(*dist_params, size)``. - The symbolic tensor distribution parameters are passed as positional arguments in - the same order as they are supplied when the ``CustomDist`` is constructed. - - random : Optional[Callable] - A callable that can be used to generate random draws from the distribution - - It must have the following signature: ``random(*dist_params, rng=None, size=None)``. - The numerical distribution parameters are passed as positional arguments in the - same order as they are supplied when the ``CustomDist`` is constructed. - The keyword arguments are ``rng``, which will provide the random variable's - associated :py:class:`~numpy.random.Generator`, and ``size``, that will represent - the desired size of the random draw. If ``None``, a ``NotImplemented`` - error will be raised when trying to draw random samples from the distribution's - prior or posterior predictive. - - logp : Optional[Callable] - A callable that calculates the log probability of some given ``value`` - conditioned on certain distribution parameter values. It must have the - following signature: ``logp(value, *dist_params)``, where ``value`` is - an PyTensor tensor that represents the distribution value, and ``dist_params`` - are the tensors that hold the values of the distribution parameters. - This function must return an PyTensor tensor. - - When the `dist` function is specified, PyMC will try to automatically - infer the `logp` when this is not provided. - - Otherwise, a ``NotImplementedError`` will be raised when trying to compute the - distribution's logp. - logcdf : Optional[Callable] - A callable that calculates the log cumulative log probability of some given - ``value`` conditioned on certain distribution parameter values. 
It must have the - following signature: ``logcdf(value, *dist_params)``, where ``value`` is - an PyTensor tensor that represents the distribution value, and ``dist_params`` - are the tensors that hold the values of the distribution parameters. - This function must return an PyTensor tensor. If ``None``, a ``NotImplementedError`` - will be raised when trying to compute the distribution's logcdf. - support_point : Optional[Callable] - A callable that can be used to compute the finete logp point of the distribution. - It must have the following signature: ``support_point(rv, size, *rv_inputs)``. - The distribution's variable is passed as the first argument ``rv``. ``size`` - is the random variable's size implied by the ``dims``, ``size`` and parameters - supplied to the distribution. Finally, ``rv_inputs`` is the sequence of the - distribution parameters, in the same order as they were supplied when the - CustomDist was created. If ``None``, a default ``support_point`` function will be - assigned that will always return 0, or an array of zeros. - ndim_supp : Optional[int] - The number of dimensions in the support of the distribution. - Inferred from signature, if provided. Defaults to assuming - a scalar distribution, i.e. ``ndim_supp = 0`` - ndims_params : Optional[Sequence[int]] - The list of number of dimensions in the support of each of the distribution's - parameters. Inferred from signature, if provided. Defaults to assuming - all parameters are scalars, i.e. ``ndims_params=[0, ...]``. - signature : Optional[str] - A numpy vectorize-like signature that indicates the number and core dimensionality - of the input parameters and sample outputs of the CustomDist. - When specified, `ndim_supp` and `ndims_params` are not needed. See examples below. - dtype : str - The dtype of the distribution. All draws and observations passed into the - distribution will be cast onto this dtype. This is not needed if an PyTensor - dist function is provided, which should already return the right dtype! - class_name : str - Name for the class which will wrap the CustomDist methods. When not specified, - it will be given the name of the model variable. - kwargs : - Extra keyword arguments are passed to the parent's class ``__new__`` method. - - - Examples - -------- - Create a CustomDist that wraps a black-box logp function. This variable cannot be - used in prior or posterior predictive sampling because no random function was provided - - .. code-block:: python - - import numpy as np - import pymc as pm - from pytensor.tensor import TensorVariable - - def logp(value: TensorVariable, mu: TensorVariable) -> TensorVariable: - return -(value - mu)**2 - - with pm.Model(): - mu = pm.Normal('mu',0,1) - pm.CustomDist( - 'custom_dist', - mu, - logp=logp, - observed=np.random.randn(100), - ) - idata = pm.sample(100) - - Provide a random function that return numerical draws. This allows one to use a - CustomDist in prior and posterior predictive sampling. - A gufunc signature was also provided, which may be used by other routines. - - .. 
code-block:: python - - from typing import Optional, Tuple - - import numpy as np - import pymc as pm - from pytensor.tensor import TensorVariable - - def logp(value: TensorVariable, mu: TensorVariable) -> TensorVariable: - return -(value - mu)**2 - - def random( - mu: np.ndarray | float, - rng: Optional[np.random.Generator] = None, - size : Optional[Tuple[int]]=None, - ) -> np.ndarray | float : - return rng.normal(loc=mu, scale=1, size=size) - - with pm.Model(): - mu = pm.Normal('mu', 0 , 1) - pm.CustomDist( - 'custom_dist', - mu, - logp=logp, - random=random, - signature="()->()", - observed=np.random.randn(100, 3), - size=(100, 3), - ) - prior = pm.sample_prior_predictive(10) - - Provide a dist function that creates a PyTensor graph built from other - PyMC distributions. PyMC can automatically infer that the logp of this - variable corresponds to a shifted Exponential distribution. - A gufunc signature was also provided, which may be used by other routines. - - .. code-block:: python - - import pymc as pm - from pytensor.tensor import TensorVariable - - def dist( - lam: TensorVariable, - shift: TensorVariable, - size: TensorVariable, - ) -> TensorVariable: - return pm.Exponential.dist(lam, size=size) + shift - - with pm.Model() as m: - lam = pm.HalfNormal("lam") - shift = -1 - pm.CustomDist( - "custom_dist", - lam, - shift, - dist=dist, - signature="(),()->()", - observed=[-1, -1, 0], - ) - - prior = pm.sample_prior_predictive() - posterior = pm.sample() - - Provide a dist function that creates a PyTensor graph built from other - PyMC distributions. PyMC can automatically infer that the logp of - this variable corresponds to a modified-PERT distribution. - - .. code-block:: python - - import pymc as pm - from pytensor.tensor import TensorVariable - - def pert( - low: TensorVariable, - peak: TensorVariable, - high: TensorVariable, - lmbda: TensorVariable, - size: TensorVariable, - ) -> TensorVariable: - range = (high - low) - s_alpha = 1 + lmbda * (peak - low) / range - s_beta = 1 + lmbda * (high - peak) / range - return pm.Beta.dist(s_alpha, s_beta, size=size) * range + low - - with pm.Model() as m: - low = pm.Normal("low", 0, 10) - peak = pm.Normal("peak", 50, 10) - high = pm.Normal("high", 100, 10) - lmbda = 4 - pm.CustomDist("pert", low, peak, high, lmbda, dist=pert, observed=[30, 35, 73]) - - m.point_logps() - - """ - - def __new__( - cls, - name, - *dist_params, - dist: Callable | None = None, - random: Callable | None = None, - logp: Callable | None = None, - logcdf: Callable | None = None, - support_point: Callable | None = None, - # TODO: Deprecate ndim_supp / ndims_params in favor of signature? - ndim_supp: int | None = None, - ndims_params: Sequence[int] | None = None, - signature: str | None = None, - dtype: str = "floatX", - **kwargs, - ): - if isinstance(kwargs.get("observed", None), dict): - raise TypeError( - "Since ``v4.0.0`` the ``observed`` parameter should be of type" - " ``pd.Series``, ``np.array``, or ``pm.Data``." - " Previous versions allowed passing distribution parameters as" - " a dictionary in ``observed``, in the current version these " - "parameters are positional arguments." - ) - dist_params = cls.parse_dist_params(dist_params) - cls.check_valid_dist_random(dist, random, dist_params) - moment = kwargs.pop("moment", None) - if moment is not None: - warnings.warn( - "`moment` argument is deprecated. 
Use `support_point` instead.", - FutureWarning, - ) - support_point = moment - if dist is not None: - kwargs.setdefault("class_name", f"CustomDist_{name}") - return _CustomSymbolicDist( - name, - *dist_params, - dist=dist, - logp=logp, - logcdf=logcdf, - support_point=support_point, - ndim_supp=ndim_supp, - ndims_params=ndims_params, - signature=signature, - **kwargs, - ) - else: - kwargs.setdefault("class_name", f"CustomDist_{name}") - return _CustomDist( - name, - *dist_params, - random=random, - logp=logp, - logcdf=logcdf, - support_point=support_point, - ndim_supp=ndim_supp, - ndims_params=ndims_params, - signature=signature, - dtype=dtype, - **kwargs, - ) - - @classmethod - def dist( - cls, - *dist_params, - dist: Callable | None = None, - random: Callable | None = None, - logp: Callable | None = None, - logcdf: Callable | None = None, - support_point: Callable | None = None, - ndim_supp: int | None = None, - ndims_params: Sequence[int] | None = None, - signature: str | None = None, - dtype: str = "floatX", - **kwargs, - ): - dist_params = cls.parse_dist_params(dist_params) - cls.check_valid_dist_random(dist, random, dist_params) - if dist is not None: - return _CustomSymbolicDist.dist( - *dist_params, - dist=dist, - logp=logp, - logcdf=logcdf, - support_point=support_point, - ndim_supp=ndim_supp, - ndims_params=ndims_params, - signature=signature, - **kwargs, - ) - else: - return _CustomDist.dist( - *dist_params, - random=random, - logp=logp, - logcdf=logcdf, - support_point=support_point, - ndim_supp=ndim_supp, - ndims_params=ndims_params, - signature=signature, - dtype=dtype, - **kwargs, - ) - - @classmethod - def parse_dist_params(cls, dist_params): - if len(dist_params) > 0 and callable(dist_params[0]): - raise TypeError( - "The DensityDist API has changed, you are using the old API " - "where logp was the first positional argument. In the current API, " - "the logp is a keyword argument, amongst other changes. Please refer " - "to the API documentation for more information on how to use the " - "new DensityDist API." - ) - return [as_tensor_variable(param) for param in dist_params] - - @classmethod - def check_valid_dist_random(cls, dist, random, dist_params): - if dist is not None and random is not None: - raise ValueError("Cannot provide both dist and random functions") - if random is not None and cls.is_symbolic_random(random, dist_params): - raise TypeError( - "API change: function passed to `random` argument should no longer return a PyTensor graph. " - "Pass such function to the `dist` argument instead." - ) - - @classmethod - def is_symbolic_random(self, random, dist_params): - if random is None: - return False - # Try calling random with symbolic inputs - try: - size = normalize_size_param(None) - with new_or_existing_block_model_access( - error_msg_on_access="Model variables cannot be created in the random function. Use the `.dist` API to create such variables." - ): - out = random(*dist_params, size) - except BlockModelAccessError: - raise - except Exception: - # If it fails we assume it was not - return False - # Confirm the output is symbolic - return isinstance(out, Variable) - - -DensityDist = CustomDist - - -def default_not_implemented(rv_name, method_name): - message = ( - f"Attempted to run {method_name} on the CustomDist '{rv_name}', " - f"but this method had not been provided when the distribution was " - f"constructed. 
Please re-build your model and provide a callable " - f"to '{rv_name}'s {method_name} keyword argument.\n" - ) - - def func(*args, **kwargs): - raise NotImplementedError(message) - - return func - - -def default_support_point(rv, size, *rv_inputs, rv_name=None, has_fallback=False): - if None not in rv.type.shape: - return pt.zeros(rv.type.shape) - elif rv.owner.op.ndim_supp == 0 and not rv_size_is_none(size): - return pt.zeros(size) - elif has_fallback: - return pt.zeros_like(rv) - else: - raise TypeError( - "Cannot safely infer the size of a multivariate random variable's support_point. " - f"Please provide a support_point function when instantiating the {rv_name} " - "random variable." - ) - - class DiracDeltaRV(SymbolicRandomVariable): name = "diracdelta" extended_signature = "[size],()->()" diff --git a/pymc/distributions/truncated.py b/pymc/distributions/truncated.py index 1b7fd7ddce..80dbcbf554 100644 --- a/pymc/distributions/truncated.py +++ b/pymc/distributions/truncated.py @@ -28,9 +28,9 @@ from pytensor.tensor.random.type import RandomType from pymc.distributions.continuous import TruncatedNormal, bounded_cont_transform +from pymc.distributions.custom import CustomSymbolicDistRV from pymc.distributions.dist_math import check_parameters from pymc.distributions.distribution import ( - CustomSymbolicDistRV, Distribution, SymbolicRandomVariable, _support_point, diff --git a/pymc/gp/util.py b/pymc/gp/util.py index 4a4a18fda4..734d36ed32 100644 --- a/pymc/gp/util.py +++ b/pymc/gp/util.py @@ -22,13 +22,9 @@ from pytensor.tensor.variable import TensorConstant from scipy.cluster.vq import kmeans -# Avoid circular dependency when importing modelcontext -from pymc.distributions.distribution import Distribution -from pymc.model import modelcontext +from pymc.model.core import modelcontext from pymc.pytensorf import compile_pymc -_ = Distribution - JITTER_DEFAULT = 1e-6 diff --git a/pymc/model/core.py b/pymc/model/core.py index 2a57dde280..305128ea62 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -48,7 +48,6 @@ from pymc.blocking import DictToArrayBijection, RaveledVars from pymc.data import GenTensorVariable, is_minibatch -from pymc.distributions.transforms import ChainedTransform, _default_transform from pymc.exceptions import ( BlockModelAccessError, ImputationWarning, @@ -1452,6 +1451,7 @@ def create_value_var( ------- TensorVariable """ + from pymc.distributions.transforms import ChainedTransform, _default_transform # Make the value variable a transformed value variable, # if there's an applicable transform diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py index 9aacc0ce59..2899a85a34 100644 --- a/scripts/run_mypy.py +++ b/scripts/run_mypy.py @@ -25,6 +25,7 @@ pymc/distributions/continuous.py pymc/distributions/dist_math.py pymc/distributions/distribution.py +pymc/distributions/custom.py pymc/distributions/multivariate.py pymc/distributions/timeseries.py pymc/distributions/truncated.py diff --git a/tests/distributions/test_custom.py b/tests/distributions/test_custom.py new file mode 100644 index 0000000000..96ba9685e1 --- /dev/null +++ b/tests/distributions/test_custom.py @@ -0,0 +1,661 @@ +# Copyright 2024 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+
+
+class TestCustomDist:
+    @pytest.mark.parametrize("size", [(), (3,), (3, 2)], ids=str)
+    def test_custom_dist_with_random(self, size):
+        with Model() as model:
+            mu = Normal("mu", 0, 1)
+            obs = CustomDist(
+                "custom_dist",
+                mu,
+                random=lambda mu, rng=None, size=None: rng.normal(loc=mu, scale=1, size=size),
+                observed=np.random.randn(100, *size),
+            )
+        assert isinstance(obs.owner.op, CustomDistRV)
+        assert obs.eval().shape == (100, *size)
+
+    def test_custom_dist_with_random_invalid_observed(self):
+        with pytest.raises(
+            TypeError,
+            match=(
+                "Since ``v4.0.0`` the ``observed`` parameter should be of type"
+                " ``pd.Series``, ``np.array``, or ``pm.Data``."
+                " Previous versions allowed passing distribution parameters as"
+                " a dictionary in ``observed``, in the current version these "
+                "parameters are positional arguments."
+            ),
+        ):
+            size = (3,)
+            with Model() as model:
+                mu = Normal("mu", 0, 1)
+                CustomDist(
+                    "custom_dist",
+                    mu,
+                    random=lambda mu, rng=None, size=None: rng.normal(loc=mu, scale=1, size=size),
+                    observed={"values": np.random.randn(100, *size)},
+                )
+
+    def test_custom_dist_without_random(self):
+        with Model() as model:
+            mu = Normal("mu", 0, 1)
+            custom_dist = CustomDist(
+                "custom_dist",
+                mu,
+                logp=lambda value, mu: logp(Normal.dist(mu, 1, size=100), value),
+                observed=np.random.randn(100),
+                initval=0,
+            )
+            assert isinstance(custom_dist.owner.op, CustomDistRV)
+            idata = sample(tune=50, draws=100, cores=1, step=Metropolis())
+
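+        # Only `logp` was provided; without a `random` implementation there is no way
+        # to draw new values, so posterior predictive sampling must fail.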
+        with pytest.raises(NotImplementedError):
+            sample_posterior_predictive(idata, model=model)
+
+    @pytest.mark.xfail(
+        raises=NotImplementedError,
+        reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388",
+    )
+    @pytest.mark.parametrize("size", [(), (3,), (3, 2)], ids=str)
+    def test_custom_dist_with_random_multivariate(self, size):
+        supp_shape = 5
+        with Model() as model:
+            mu = Normal("mu", 0, 1, size=supp_shape)
+            obs = CustomDist(
+                "custom_dist",
+                mu,
+                random=lambda mu, rng=None, size=None: rng.multivariate_normal(
+                    mean=mu, cov=np.eye(len(mu)), size=size
+                ),
+                observed=np.random.randn(100, *size, supp_shape),
+                ndims_params=[1],
+                ndim_supp=1,
+            )
+
+        assert isinstance(obs.owner.op, CustomDistRV)
+        assert obs.eval().shape == (100, *size, supp_shape)
+
+    def test_serialize_custom_dist(self):
+        def func(x):
+            return -2 * (x**2).sum()
+
+        def random(rng, size):
+            return rng.uniform(-2, 2, size=size)
+
+        with Model():
+            Normal("x")
+            y = CustomDist("y", logp=func, random=random)
+            y_dist = CustomDist.dist(logp=func, random=random)
+            Deterministic("y_dist", y_dist)
+            assert isinstance(y.owner.op, CustomDistRV)
+            assert isinstance(y_dist.owner.op, CustomDistRV)
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
+                sample(draws=5, tune=1, mp_ctx="spawn")
+
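+        # The callables attached to the CustomDist must survive a pickle round-trip;
+        # the spawn-based multiprocessing used above relies on exactly this.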
+        cloudpickle.loads(cloudpickle.dumps(y))
+        cloudpickle.loads(cloudpickle.dumps(y_dist))
+
+    def test_custom_dist_old_api_error(self):
+        with Model():
+            with pytest.raises(
+                TypeError, match="The DensityDist API has changed, you are using the old API"
+            ):
+                CustomDist("a", lambda x: x)
+
+    @pytest.mark.xfail(
+        raises=NotImplementedError,
+        reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388",
+    )
+    @pytest.mark.parametrize("size", [None, (), (2,)], ids=str)
+    def test_custom_dist_multivariate_logp(self, size):
+        supp_shape = 5
+        with Model() as model:
+
+            def logp(value, mu):
+                return MvNormal.logp(value, mu, pt.eye(mu.shape[0]))
+
+            mu = Normal("mu", size=supp_shape)
+            a = CustomDist("a", mu, logp=logp, ndims_params=[1], ndim_supp=1, size=size)
+
+        assert isinstance(a.owner.op, CustomDistRV)
+        mu_test_value = npr.normal(loc=0, scale=1, size=supp_shape).astype(pytensor.config.floatX)
+        a_test_value = npr.normal(
+            loc=mu_test_value, scale=1, size=(*to_tuple(size), supp_shape)
+        ).astype(pytensor.config.floatX)
+        log_densityf = model.compile_logp(vars=[a], sum=False)
+        assert log_densityf({"a": a_test_value, "mu": mu_test_value})[0].shape == to_tuple(size)
+
+    @pytest.mark.parametrize(
+        "support_point, size, expected",
+        [
+            (None, None, 0.0),
+            (None, 5, np.zeros(5)),
+            ("custom_support_point", (), 5),
+            ("custom_support_point", (2, 5), np.full((2, 5), 5)),
+        ],
+    )
+    def test_custom_dist_default_support_point_univariate(self, support_point, size, expected):
+        if support_point == "custom_support_point":
+            support_point = lambda rv, size, *rv_inputs: 5 * pt.ones(size, dtype=rv.dtype)  # noqa E731
+        with Model() as model:
+            x = CustomDist("x", support_point=support_point, size=size)
+        assert isinstance(x.owner.op, CustomDistRV)
+        assert_support_point_is_expected(model, expected, check_finite_logp=False)
+
+    def test_custom_dist_moment_future_warning(self):
+        moment = lambda rv, size, *rv_inputs: 5 * pt.ones(size, dtype=rv.dtype)  # noqa E731
+        with Model() as model:
+            with pytest.warns(
+                FutureWarning, match="`moment` argument is deprecated. Use `support_point` instead."
+            ):
+                x = CustomDist("x", moment=moment, size=())
+        assert_support_point_is_expected(model, 5, check_finite_logp=False)
+
+    @pytest.mark.parametrize("size", [(), (2,), (3, 2)], ids=str)
+    def test_custom_dist_custom_support_point_univariate(self, size):
+        def density_support_point(rv, size, mu):
+            return (pt.ones(size) * mu).astype(rv.dtype)
+
+        mu_val = np.array(np.random.normal(loc=2, scale=1)).astype(pytensor.config.floatX)
+        with Model():
+            mu = Normal("mu")
+            a = CustomDist("a", mu, support_point=density_support_point, size=size)
+        assert isinstance(a.owner.op, CustomDistRV)
+        evaled_support_point = support_point(a).eval({mu: mu_val})
+        assert evaled_support_point.shape == to_tuple(size)
+        assert np.all(evaled_support_point == mu_val)
+
+    @pytest.mark.xfail(
+        raises=NotImplementedError,
+        reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388",
+    )
+    @pytest.mark.parametrize("size", [(), (2,), (3, 2)], ids=str)
+    def test_custom_dist_custom_support_point_multivariate(self, size):
+        def density_support_point(rv, size, mu):
+            return (pt.ones(size)[..., None] * mu).astype(rv.dtype)
+
+        mu_val = np.random.normal(loc=2, scale=1, size=5).astype(pytensor.config.floatX)
+        with Model():
+            mu = Normal("mu", size=5)
+            a = CustomDist(
+                "a",
+                mu,
+                support_point=density_support_point,
+                ndims_params=[1],
+                ndim_supp=1,
+                size=size,
+            )
+        assert isinstance(a.owner.op, CustomDistRV)
+        evaled_support_point = support_point(a).eval({mu: mu_val})
+        assert evaled_support_point.shape == (*to_tuple(size), 5)
+        assert np.all(evaled_support_point == mu_val)
+
+    @pytest.mark.xfail(
+        raises=NotImplementedError,
+        reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388",
+    )
+    @pytest.mark.parametrize(
+        "with_random, size",
+        [
+            (True, ()),
+            (True, (2,)),
+            (True, (3, 2)),
+            (False, ()),
+            (False, (2,)),
+        ],
+    )
+    def test_custom_dist_default_support_point_multivariate(self, with_random, size):
+        def _random(mu, rng=None, size=None):
+            return rng.normal(mu, scale=1, size=to_tuple(size) + mu.shape)
+
+        if with_random:
+            random = _random
+        else:
+            random = None
+
+        mu_val = np.random.normal(loc=2, scale=1, size=5).astype(pytensor.config.floatX)
+        with Model():
+            mu = Normal("mu", size=5)
+            a = CustomDist("a", mu, random=random, ndims_params=[1], ndim_supp=1, size=size)
+        assert isinstance(a.owner.op, CustomDistRV)
+        if with_random:
+            evaled_support_point = support_point(a).eval({mu: mu_val})
+            assert evaled_support_point.shape == (*to_tuple(size), 5)
+            assert np.all(evaled_support_point == 0)
+        else:
+            with pytest.raises(
+                TypeError,
+                match="Cannot safely infer the size of a multivariate random variable's support_point.",
+            ):
+                evaled_support_point = support_point(a).eval({mu: mu_val})
+
+    def test_dist(self):
+        mu = 1
+        x = CustomDist.dist(
+            mu,
+            logp=lambda value, mu: logp(Normal.dist(mu), value),
+            random=lambda mu, rng=None, size=None: rng.normal(loc=mu, scale=1, size=size),
+            shape=(3,),
+        )
+
+        x = cloudpickle.loads(cloudpickle.dumps(x))
+
+        test_value = draw(x, random_seed=1)
+        assert np.all(test_value == draw(x, random_seed=1))
+
+        x_logp = logp(x, test_value)
+        assert np.allclose(x_logp.eval(), st.norm(1).logpdf(test_value))
+
+
+class TestCustomSymbolicDist:
+    def test_basic(self):
+        def custom_dist(mu, sigma, size):
+            return pt.exp(Normal.dist(mu, sigma, size=size))
+
+        with Model() as m:
+            mu = Normal("mu")
+            sigma = HalfNormal("sigma")
+            lognormal = CustomDist(
+                "lognormal",
+                mu,
+                sigma,
+                dist=custom_dist,
+                size=(10,),
+                transform=log,
+                initval=np.ones(10),
+            )
+
+        assert isinstance(lognormal.owner.op, CustomSymbolicDistRV)
+
+        # Fix mu and sigma, so that all sources of randomness come from the symbolic RV
+        draws = draw(lognormal, draws=3, givens={mu: 0.0, sigma: 1.0})
+        assert draws.shape == (3, 10)
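+        # 3 independent draws of 10 elements each: all 30 values must differ if the
+        # op's internal RNG is actually driving the randomness.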
"lognormal", + mu, + sigma, + dist=custom_dist, + size=(10,), + transform=log, + initval=np.ones(10), + ) + + assert isinstance(lognormal.owner.op, CustomSymbolicDistRV) + + # Fix mu and sigma, so that all source of randomness comes from the symbolic RV + draws = draw(lognormal, draws=3, givens={mu: 0.0, sigma: 1.0}) + assert draws.shape == (3, 10) + assert np.unique(draws).size == 30 + + with Model() as ref_m: + mu = Normal("mu") + sigma = HalfNormal("sigma") + LogNormal("lognormal", mu, sigma, size=(10,)) + + ip = m.initial_point() + np.testing.assert_allclose(m.compile_logp()(ip), ref_m.compile_logp()(ip)) + + @pytest.mark.parametrize( + "dist_params, size, expected, dist_fn", + [ + ( + (5, 1), + None, + np.exp(5), + lambda mu, sigma, size: pt.exp(Normal.dist(mu, sigma, size=size)), + ), + ( + (2, np.ones(5)), + None, + np.exp(2 + np.ones(5)), + lambda mu, sigma, size: pt.exp(Normal.dist(mu, sigma, size=size) + 1.0), + ), + ( + (1, 2), + None, + np.sqrt(np.exp(1 + 0.5 * 2**2)), + lambda mu, sigma, size: pt.sqrt(LogNormal.dist(mu, sigma, size=size)), + ), + ( + (4,), + (3,), + np.log([4, 4, 4]), + lambda nu, size: pt.log(ChiSquared.dist(nu, size=size)), + ), + ( + (12, 1), + None, + 12, + lambda mu1, sigma, size: Normal.dist(mu1, sigma, size=size), + ), + ], + ) + def test_custom_dist_default_support_point(self, dist_params, size, expected, dist_fn): + with Model() as model: + CustomDist("x", *dist_params, dist=dist_fn, size=size) + assert_support_point_is_expected(model, expected) + + def test_custom_dist_default_support_point_scan(self): + def scan_step(left, right): + x = Uniform.dist(left, right) + x_update = collect_default_updates([x]) + return x, x_update + + def dist(size): + xs, updates = scan( + fn=scan_step, + sequences=[ + pt.as_tensor_variable(np.array([-4, -3])), + pt.as_tensor_variable(np.array([-2, -1])), + ], + name="xs", + ) + return xs + + with Model() as model: + CustomDist("x", dist=dist) + assert_support_point_is_expected(model, np.array([-3, -2])) + + def test_custom_dist_default_support_point_scan_recurring(self): + def scan_step(xtm1): + x = Normal.dist(xtm1 + 1) + x_update = collect_default_updates([x]) + return x, x_update + + def dist(size): + xs, _ = scan( + fn=scan_step, + outputs_info=pt.as_tensor_variable(np.array([0])).astype(float), + n_steps=3, + name="xs", + ) + return xs + + with Model() as model: + CustomDist("x", dist=dist) + assert_support_point_is_expected(model, np.array([[1], [2], [3]])) + + @pytest.mark.parametrize( + "left, right, size, expected", + [ + (-1, 1, None, 0 + 5), + (-3, -1, None, -2 + 5), + (-3, 1, (3,), np.array([-1 + 5, -1 + 5, -1 + 5])), + ], + ) + def test_custom_dist_default_support_point_nested(self, left, right, size, expected): + def dist_fn(left, right, size): + return Truncated.dist(Normal.dist(0, 1), left, right, size=size) + 5 + + with Model() as model: + CustomDist("x", left, right, size=size, dist=dist_fn) + assert_support_point_is_expected(model, expected) + + def test_logcdf_inference(self): + def custom_dist(mu, sigma, size): + return pt.exp(Normal.dist(mu, sigma, size=size)) + + mu = 1 + sigma = 1.25 + test_value = 0.9 + + custom_lognormal = CustomDist.dist(mu, sigma, dist=custom_dist) + ref_lognormal = LogNormal.dist(mu, sigma) + + np.testing.assert_allclose( + logcdf(custom_lognormal, test_value).eval(), + logcdf(ref_lognormal, test_value).eval(), + ) + + def test_random_multiple_rngs(self): + def custom_dist(p, sigma, size): + idx = Bernoulli.dist(p=p) + if rv_size_is_none(size): + size = pt.broadcast_shape(p, 
+        assert_support_point_is_expected(model, np.array([-3, -2]))
+
+    def test_custom_dist_default_support_point_scan_recurring(self):
+        def scan_step(xtm1):
+            x = Normal.dist(xtm1 + 1)
+            x_update = collect_default_updates([x])
+            return x, x_update
+
+        def dist(size):
+            xs, _ = scan(
+                fn=scan_step,
+                outputs_info=pt.as_tensor_variable(np.array([0])).astype(float),
+                n_steps=3,
+                name="xs",
+            )
+            return xs
+
+        with Model() as model:
+            CustomDist("x", dist=dist)
+        assert_support_point_is_expected(model, np.array([[1], [2], [3]]))
+
+    @pytest.mark.parametrize(
+        "left, right, size, expected",
+        [
+            (-1, 1, None, 0 + 5),
+            (-3, -1, None, -2 + 5),
+            (-3, 1, (3,), np.array([-1 + 5, -1 + 5, -1 + 5])),
+        ],
+    )
+    def test_custom_dist_default_support_point_nested(self, left, right, size, expected):
+        def dist_fn(left, right, size):
+            return Truncated.dist(Normal.dist(0, 1), left, right, size=size) + 5
+
+        with Model() as model:
+            CustomDist("x", left, right, size=size, dist=dist_fn)
+        assert_support_point_is_expected(model, expected)
+
+    def test_logcdf_inference(self):
+        def custom_dist(mu, sigma, size):
+            return pt.exp(Normal.dist(mu, sigma, size=size))
+
+        mu = 1
+        sigma = 1.25
+        test_value = 0.9
+
+        custom_lognormal = CustomDist.dist(mu, sigma, dist=custom_dist)
+        ref_lognormal = LogNormal.dist(mu, sigma)
+
+        np.testing.assert_allclose(
+            logcdf(custom_lognormal, test_value).eval(),
+            logcdf(ref_lognormal, test_value).eval(),
+        )
+
+    def test_random_multiple_rngs(self):
+        def custom_dist(p, sigma, size):
+            idx = Bernoulli.dist(p=p)
+            if rv_size_is_none(size):
+                size = pt.broadcast_shape(p, sigma)
+            comps = Normal.dist([-sigma, sigma], 1e-1, size=(*size, 2)).T
+            return comps[idx]
+
+        customdist = CustomDist.dist(
+            0.5,
+            10.0,
+            dist=custom_dist,
+            size=(10,),
+        )
+
+        assert isinstance(customdist.owner.op, CustomSymbolicDistRV)
+
+        node = customdist.owner
+        assert len(node.inputs) == 5  # Size, 2 inputs and 2 RNGs
+        assert len(node.outputs) == 3  # RV and 2 updated RNGs
+        assert len(node.op.update(node)) == 2
+
+        draws = draw(customdist, draws=2, random_seed=123)
+        assert np.unique(draws).size == 20
+
+    def test_custom_methods(self):
+        def custom_dist(mu, size):
+            return DiracDelta.dist(mu, size=size)
+
+        def custom_support_point(rv, size, mu):
+            return pt.full_like(rv, mu + 1)
+
+        def custom_logp(value, mu):
+            return pt.full_like(value, mu + 2)
+
+        def custom_logcdf(value, mu):
+            return pt.full_like(value, mu + 3)
+
+        customdist = CustomDist.dist(
+            [np.e, np.e],
+            dist=custom_dist,
+            support_point=custom_support_point,
+            logp=custom_logp,
+            logcdf=custom_logcdf,
+        )
+
+        assert isinstance(customdist.owner.op, CustomSymbolicDistRV)
+
+        np.testing.assert_allclose(draw(customdist), [np.e, np.e])
+        np.testing.assert_allclose(support_point(customdist).eval(), [np.e + 1, np.e + 1])
+        np.testing.assert_allclose(logp(customdist, [0, 0]).eval(), [np.e + 2, np.e + 2])
+        np.testing.assert_allclose(logcdf(customdist, [0, 0]).eval(), [np.e + 3, np.e + 3])
+
+    def test_change_size(self):
+        def custom_dist(mu, sigma, size):
+            return pt.exp(Normal.dist(mu, sigma, size=size))
+
+        lognormal = CustomDist.dist(
+            0,
+            1,
+            dist=custom_dist,
+            size=(10,),
+        )
+        assert isinstance(lognormal.owner.op, CustomSymbolicDistRV)
+        assert tuple(lognormal.shape.eval()) == (10,)
+
+        new_lognormal = change_dist_size(lognormal, new_size=(2, 5))
+        assert isinstance(new_lognormal.owner.op, CustomSymbolicDistRV)
+        assert tuple(new_lognormal.shape.eval()) == (2, 5)
+
+        new_lognormal = change_dist_size(lognormal, new_size=(2, 5), expand=True)
+        assert isinstance(new_lognormal.owner.op, CustomSymbolicDistRV)
+        assert tuple(new_lognormal.shape.eval()) == (2, 5, 10)
+
+    def test_error_model_access(self):
+        def custom_dist(size):
+            return Flat("Flat", size=size)
+
+        with Model() as m:
+            with pytest.raises(
+                BlockModelAccessError,
+                match="Model variables cannot be created in the dist function",
+            ):
+                CustomDist("custom_dist", dist=custom_dist)
+
+    def test_api_change_error(self):
+        def old_random(size):
+            return Flat.dist(size=size)
+
+        # Old API raises
+        with pytest.raises(TypeError, match="API change: function passed to `random` argument"):
+            CustomDist.dist(random=old_random, class_name="custom_dist")
+
+        # New API is fine
+        CustomDist.dist(dist=old_random, class_name="custom_dist")
+
+    def test_scan(self):
+        def trw(nu, sigma, steps, size):
+            if rv_size_is_none(size):
+                size = ()
+
+            def step(xtm1, nu, sigma):
+                x = StudentT.dist(nu=nu, mu=xtm1, sigma=sigma, shape=size)
+                return x, collect_default_updates([x])
+
+            xs, _ = scan(
+                fn=step,
+                outputs_info=pt.zeros(size),
+                non_sequences=[nu, sigma],
+                n_steps=steps,
+            )
+
+            # Logprob inference cannot be derived yet https://github.com/pymc-devs/pymc/issues/6360
+            # xs = swapaxes(xs, 0, -1)
+
+            return xs
+
+        nu = 4
+        sigma = 0.7
+        steps = 99
+        batch_size = 3
+        x = CustomDist.dist(nu, sigma, steps, dist=trw, size=batch_size)
+
+        x_draw = draw(x, random_seed=1)
+        assert x_draw.shape == (steps, batch_size)
+        np.testing.assert_allclose(draw(x, random_seed=1), x_draw)
+        assert not np.any(draw(x, random_seed=2) == x_draw)
+
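+        # Reference: a RandomWalk with a Flat init contributes logp 0 for the initial
+        # state, so prepending the zero initial value to the draw must give the same total.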
+        ref_dist = RandomWalk.dist(
+            init_dist=Flat.dist(),
+            innovation_dist=StudentT.dist(nu=nu, sigma=sigma),
+            steps=steps,
+            size=(batch_size,),
+        )
+        ref_val = pt.concatenate([np.zeros((1, batch_size)), x_draw]).T
+
+        np.testing.assert_allclose(
+            logp(x, x_draw).eval().sum(0),
+            logp(ref_dist, ref_val).eval(),
+        )
+
+    def test_inferred_logp_mixture(self):
+        def shifted_normal(mu, sigma, size):
+            return mu + Normal.dist(0, sigma, shape=size)
+
+        mus = [3.5, -4.3]
+        sds = [1.5, 2.3]
+        w = [0.3, 0.7]
+        with Model() as m:
+            comp_dists = [
+                CustomDist.dist(mus[0], sds[0], dist=shifted_normal),
+                CustomDist.dist(mus[1], sds[1], dist=shifted_normal),
+            ]
+            Mixture("mix", w=w, comp_dists=comp_dists)
+
+        test_value = 0.1
+        np.testing.assert_allclose(
+            m.compile_logp()({"mix": test_value}),
+            logp(NormalMixture.dist(w=w, mu=mus, sigma=sds), test_value).eval(),
+        )
+
+    def test_symbolic_dist(self):
+        # Test we can create a SymbolicDist inside a CustomDist
+        def dist(size):
+            return Truncated.dist(Beta.dist(1, 1, size=size), lower=0.1, upper=0.9)
+
+        assert CustomDist.dist(dist=dist)
+
+    def test_nested_custom_dist(self):
+        """Test that we can create a CustomDist that creates another CustomDist."""
+
+        def dist(size=None):
+            def inner_dist(size=None):
+                return Normal.dist(size=size)
+
+            inner_dist = CustomDist.dist(dist=inner_dist, size=size)
+            return pt.exp(inner_dist)
+
+        rv = CustomDist.dist(dist=dist)
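+        # exp(Normal(0, 1)) is LogNormal(0, 1), so the logp inferred through the nested
+        # graphs can be checked against LogNormal exactly.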
+        np.testing.assert_allclose(
+            logp(rv, 1.0).eval(),
+            logp(LogNormal.dist(), 1.0).eval(),
+        )
+
+    def test_signature(self):
+        def dist(p, size):
+            return -Categorical.dist(p=p, size=size)
+
+        out = CustomDist.dist([0.25, 0.75], dist=dist, signature="(p)->()")
+        # Size and updates are added automatically to the signature
+        assert out.owner.op.extended_signature == "[size],(p),[rng]->(),[rng]"
+        assert out.owner.op.ndim_supp == 0
+        assert out.owner.op.ndims_params == [1]
+
+        # When recreated internally, the whole signature may already be known
+        out = CustomDist.dist([0.25, 0.75], dist=dist, signature="[size],(p),[rng]->(),[rng]")
+        assert out.owner.op.extended_signature == "[size],(p),[rng]->(),[rng]"
+        assert out.owner.op.ndim_supp == 0
+        assert out.owner.op.ndims_params == [1]
+
+        # A safe signature can be inferred from ndim_supp and ndims_params
+        out = CustomDist.dist([0.25, 0.75], dist=dist, ndim_supp=0, ndims_params=[1])
+        assert out.owner.op.extended_signature == "[size],(i00),[rng]->(),[rng]"
+        assert out.owner.op.ndim_supp == 0
+        assert out.owner.op.ndims_params == [1]
+
+        # Otherwise, by default we assume everything is scalar, even though it's wrong in this case
+        out = CustomDist.dist([0.25, 0.75], dist=dist)
+        assert out.owner.op.extended_signature == "[size],(),[rng]->(),[rng]"
+        assert out.owner.op.ndim_supp == 0
+        assert out.owner.op.ndims_params == [0]
diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py
index c87336d409..a0668c8bda 100644
--- a/tests/distributions/test_distribution.py
+++ b/tests/distributions/test_distribution.py
@@ -14,7 +14,6 @@
 import sys
 import warnings
 
-import cloudpickle
 import numpy as np
 import numpy.random as npr
 import numpy.testing as npt
@@ -23,7 +22,7 @@
 import pytest
 import scipy.stats as st
 
-from pytensor import scan, shared
+from pytensor import shared
 from pytensor.tensor import TensorVariable
 
 import pymc as pm
 
@@ -31,30 +30,20 @@
 from pymc.distributions import (
     Censored,
     Flat,
-    HalfNormal,
-    LogNormal,
     MvNormal,
     MvStudentT,
     Normal,
 )
 from pymc.distributions.distribution import (
-    CustomDist,
-    CustomDistRV,
-    CustomSymbolicDistRV,
-    DiracDelta,
     PartialObservedRV,
     SymbolicRandomVariable,
     _support_point,
     create_partial_observed_rv,
     support_point,
 )
-from pymc.distributions.shape_utils import change_dist_size, rv_size_is_none, to_tuple
-from pymc.distributions.transforms import log
-from pymc.exceptions import BlockModelAccessError
-from pymc.logprob.basic import conditional_logp, logcdf, logp
-from pymc.model import Deterministic, Model
-from pymc.pytensorf import collect_default_updates, compile_pymc
-from pymc.sampling import draw, sample
+from pymc.distributions.shape_utils import change_dist_size
+from pymc.logprob.basic import conditional_logp, logp
+from pymc.pytensorf import compile_pymc
 from pymc.testing import (
     BaseTestDistributionRandom,
     I,
@@ -166,615 +155,6 @@ def test_all_distributions_have_support_points():
     )
 
 
-class TestCustomDist:
-    @pytest.mark.parametrize("size", [(), (3,), (3, 2)], ids=str)
-    def test_custom_dist_with_random(self, size):
-        with Model() as model:
-            mu = Normal("mu", 0, 1)
-            obs = CustomDist(
-                "custom_dist",
-                mu,
-                random=lambda mu, rng=None, size=None: rng.normal(loc=mu, scale=1, size=size),
-                observed=np.random.randn(100, *size),
-            )
-        assert isinstance(obs.owner.op, CustomDistRV)
-        assert obs.eval().shape == (100, *size)
-
-    def test_custom_dist_with_random_invalid_observed(self):
-        with pytest.raises(
-            TypeError,
-            match=(
-                "Since ``v4.0.0`` the ``observed`` parameter should be of type"
-                " ``pd.Series``, ``np.array``, or ``pm.Data``."
-                " Previous versions allowed passing distribution parameters as"
-                " a dictionary in ``observed``, in the current version these "
-                "parameters are positional arguments."
-            ),
-        ):
-            size = (3,)
-            with Model() as model:
-                mu = Normal("mu", 0, 1)
-                CustomDist(
-                    "custom_dist",
-                    mu,
-                    random=lambda mu, rng=None, size=None: rng.normal(loc=mu, scale=1, size=size),
-                    observed={"values": np.random.randn(100, *size)},
-                )
-
-    def test_custom_dist_without_random(self):
-        with Model() as model:
-            mu = Normal("mu", 0, 1)
-            custom_dist = CustomDist(
-                "custom_dist",
-                mu,
-                logp=lambda value, mu: logp(pm.Normal.dist(mu, 1, size=100), value),
-                observed=np.random.randn(100),
-                initval=0,
-            )
-            assert isinstance(custom_dist.owner.op, CustomDistRV)
-            idata = sample(tune=50, draws=100, cores=1, step=pm.Metropolis())
-
-        with pytest.raises(NotImplementedError):
-            pm.sample_posterior_predictive(idata, model=model)
-
-    @pytest.mark.xfail(
-        NotImplementedError,
-        reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388",
-    )
-    @pytest.mark.parametrize("size", [(), (3,), (3, 2)], ids=str)
-    def test_custom_dist_with_random_multivariate(self, size):
-        supp_shape = 5
-        with Model() as model:
-            mu = Normal("mu", 0, 1, size=supp_shape)
-            obs = CustomDist(
-                "custom_dist",
-                mu,
-                random=lambda mu, rng=None, size=None: rng.multivariate_normal(
-                    mean=mu, cov=np.eye(len(mu)), size=size
-                ),
-                observed=np.random.randn(100, *size, supp_shape),
-                ndims_params=[1],
-                ndim_supp=1,
-            )
-
-        assert isinstance(obs.owner.op, CustomDistRV)
-        assert obs.eval().shape == (100, *size, supp_shape)
-
-    def test_serialize_custom_dist(self):
-        def func(x):
-            return -2 * (x**2).sum()
-
-        def random(rng, size):
-            return rng.uniform(-2, 2, size=size)
-
-        with Model():
-            Normal("x")
-            y = CustomDist("y", logp=func, random=random)
-            y_dist = CustomDist.dist(logp=func, random=random)
-            Deterministic("y_dist", y_dist)
-            assert isinstance(y.owner.op, CustomDistRV)
-            assert isinstance(y_dist.owner.op, CustomDistRV)
-            with warnings.catch_warnings():
-                warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
-                sample(draws=5, tune=1, mp_ctx="spawn")
-
-        cloudpickle.loads(cloudpickle.dumps(y))
-        cloudpickle.loads(cloudpickle.dumps(y_dist))
-
-    def test_custom_dist_old_api_error(self):
-        with Model():
-            with pytest.raises(
-                TypeError, match="The DensityDist API has changed, you are using the old API"
-            ):
-                CustomDist("a", lambda x: x)
-
-    @pytest.mark.xfail(
-        NotImplementedError,
-        reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388",
-    )
-    @pytest.mark.parametrize("size", [None, (), (2,)], ids=str)
-    def test_custom_dist_multivariate_logp(self, size):
-        supp_shape = 5
-        with Model() as model:
-
-            def logp(value, mu):
-                return pm.MvNormal.logp(value, mu, pt.eye(mu.shape[0]))
-
-            mu = Normal("mu", size=supp_shape)
-            a = CustomDist("a", mu, logp=logp, ndims_params=[1], ndim_supp=1, size=size)
-
-        assert isinstance(a.owner.op, CustomDistRV)
-        mu_test_value = npr.normal(loc=0, scale=1, size=supp_shape).astype(pytensor.config.floatX)
-        a_test_value = npr.normal(
-            loc=mu_test_value, scale=1, size=(*to_tuple(size), supp_shape)
-        ).astype(pytensor.config.floatX)
-        log_densityf = model.compile_logp(vars=[a], sum=False)
-        assert log_densityf({"a": a_test_value, "mu": mu_test_value})[0].shape == to_tuple(size)
-
-    @pytest.mark.parametrize(
-        "support_point, size, expected",
-        [
-            (None, None, 0.0),
-            (None, 5, np.zeros(5)),
-            ("custom_support_point", (), 5),
-            ("custom_support_point", (2, 5), np.full((2, 5), 5)),
-        ],
-    )
-    def test_custom_dist_default_support_point_univariate(self, support_point, size, expected):
-        if support_point == "custom_support_point":
-            support_point = lambda rv, size, *rv_inputs: 5 * pt.ones(size, dtype=rv.dtype)  # noqa E731
-        with pm.Model() as model:
-            x = CustomDist("x", support_point=support_point, size=size)
-        assert isinstance(x.owner.op, CustomDistRV)
-        assert_support_point_is_expected(model, expected, check_finite_logp=False)
-
-    def test_custom_dist_moment_future_warning(self):
-        moment = lambda rv, size, *rv_inputs: 5 * pt.ones(size, dtype=rv.dtype)  # noqa E731
-        with pm.Model() as model:
-            with pytest.warns(
-                FutureWarning, match="`moment` argument is deprecated. Use `support_point` instead."
-            ):
-                x = CustomDist("x", moment=moment, size=())
-        assert_support_point_is_expected(model, 5, check_finite_logp=False)
-
-    @pytest.mark.parametrize("size", [(), (2,), (3, 2)], ids=str)
-    def test_custom_dist_custom_support_point_univariate(self, size):
-        def density_support_point(rv, size, mu):
-            return (pt.ones(size) * mu).astype(rv.dtype)
-
-        mu_val = np.array(np.random.normal(loc=2, scale=1)).astype(pytensor.config.floatX)
-        with Model():
-            mu = Normal("mu")
-            a = CustomDist("a", mu, support_point=density_support_point, size=size)
-        assert isinstance(a.owner.op, CustomDistRV)
-        evaled_support_point = support_point(a).eval({mu: mu_val})
-        assert evaled_support_point.shape == to_tuple(size)
-        assert np.all(evaled_support_point == mu_val)
-
-    @pytest.mark.xfail(
-        NotImplementedError,
-        reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388",
-    )
-    @pytest.mark.parametrize("size", [(), (2,), (3, 2)], ids=str)
-    def test_custom_dist_custom_support_point_multivariate(self, size):
-        def density_support_point(rv, size, mu):
-            return (pt.ones(size)[..., None] * mu).astype(rv.dtype)
-
-        mu_val = np.random.normal(loc=2, scale=1, size=5).astype(pytensor.config.floatX)
-        with Model():
-            mu = Normal("mu", size=5)
-            a = CustomDist(
-                "a",
-                mu,
-                support_point=density_support_point,
-                ndims_params=[1],
-                ndim_supp=1,
-                size=size,
-            )
-        assert isinstance(a.owner.op, CustomDistRV)
-        evaled_support_point = support_point(a).eval({mu: mu_val})
-        assert evaled_support_point.shape == (*to_tuple(size), 5)
-        assert np.all(evaled_support_point == mu_val)
-
-    @pytest.mark.xfail(
-        NotImplementedError,
-        reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388",
-    )
-    @pytest.mark.parametrize(
-        "with_random, size",
-        [
-            (True, ()),
-            (True, (2,)),
-            (True, (3, 2)),
-            (False, ()),
-            (False, (2,)),
-        ],
-    )
-    def test_custom_dist_default_support_point_multivariate(self, with_random, size):
-        def _random(mu, rng=None, size=None):
-            return rng.normal(mu, scale=1, size=to_tuple(size) + mu.shape)
-
-        if with_random:
-            random = _random
-        else:
-            random = None
-
-        mu_val = np.random.normal(loc=2, scale=1, size=5).astype(pytensor.config.floatX)
-        with Model():
-            mu = Normal("mu", size=5)
-            a = CustomDist("a", mu, random=random, ndims_params=[1], ndim_supp=1, size=size)
-        assert isinstance(a.owner.op, CustomDistRV)
-        if with_random:
-            evaled_support_point = support_point(a).eval({mu: mu_val})
-            assert evaled_support_point.shape == (*to_tuple(size), 5)
-            assert np.all(evaled_support_point == 0)
-        else:
-            with pytest.raises(
-                TypeError,
-                match="Cannot safely infer the size of a multivariate random variable's support_point.",
-            ):
-                evaled_support_point = support_point(a).eval({mu: mu_val})
-
-    def test_dist(self):
-        mu = 1
-        x = pm.CustomDist.dist(
-            mu,
-            logp=lambda value, mu: pm.logp(pm.Normal.dist(mu), value),
-            random=lambda mu, rng=None, size=None: rng.normal(loc=mu, scale=1, size=size),
-            shape=(3,),
-        )
-
-        x = cloudpickle.loads(cloudpickle.dumps(x))
-
-        test_value = pm.draw(x, random_seed=1)
-        assert np.all(test_value == pm.draw(x, random_seed=1))
-
-        x_logp = pm.logp(x, test_value)
-        assert np.allclose(x_logp.eval(), st.norm(1).logpdf(test_value))
-
-
-class TestCustomSymbolicDist:
-    def test_basic(self):
-        def custom_dist(mu, sigma, size):
-            return pt.exp(pm.Normal.dist(mu, sigma, size=size))
-
-        with Model() as m:
-            mu = Normal("mu")
-            sigma = HalfNormal("sigma")
-            lognormal = CustomDist(
-                "lognormal",
-                mu,
-                sigma,
-                dist=custom_dist,
-                size=(10,),
-                transform=log,
-                initval=np.ones(10),
-            )
-
-        assert isinstance(lognormal.owner.op, CustomSymbolicDistRV)
-
-        # Fix mu and sigma, so that all source of randomness comes from the symbolic RV
-        draws = pm.draw(lognormal, draws=3, givens={mu: 0.0, sigma: 1.0})
-        assert draws.shape == (3, 10)
-        assert np.unique(draws).size == 30
-
-        with Model() as ref_m:
-            mu = Normal("mu")
-            sigma = HalfNormal("sigma")
-            LogNormal("lognormal", mu, sigma, size=(10,))
-
-        ip = m.initial_point()
-        np.testing.assert_allclose(m.compile_logp()(ip), ref_m.compile_logp()(ip))
-
-    @pytest.mark.parametrize(
-        "dist_params, size, expected, dist_fn",
-        [
-            (
-                (5, 1),
-                None,
-                np.exp(5),
-                lambda mu, sigma, size: pt.exp(pm.Normal.dist(mu, sigma, size=size)),
-            ),
-            (
-                (2, np.ones(5)),
-                None,
-                np.exp(2 + np.ones(5)),
-                lambda mu, sigma, size: pt.exp(pm.Normal.dist(mu, sigma, size=size) + 1.0),
-            ),
-            (
-                (1, 2),
-                None,
-                np.sqrt(np.exp(1 + 0.5 * 2**2)),
-                lambda mu, sigma, size: pt.sqrt(pm.LogNormal.dist(mu, sigma, size=size)),
-            ),
-            (
-                (4,),
-                (3,),
-                np.log([4, 4, 4]),
-                lambda nu, size: pt.log(pm.ChiSquared.dist(nu, size=size)),
-            ),
-            (
-                (12, 1),
-                None,
-                12,
-                lambda mu1, sigma, size: pm.Normal.dist(mu1, sigma, size=size),
-            ),
-        ],
-    )
-    def test_custom_dist_default_support_point(self, dist_params, size, expected, dist_fn):
-        with Model() as model:
-            CustomDist("x", *dist_params, dist=dist_fn, size=size)
-        assert_support_point_is_expected(model, expected)
-
-    def test_custom_dist_default_support_point_scan(self):
-        def scan_step(left, right):
-            x = pm.Uniform.dist(left, right)
-            x_update = collect_default_updates([x])
-            return x, x_update
-
-        def dist(size):
-            xs, updates = scan(
-                fn=scan_step,
-                sequences=[
-                    pt.as_tensor_variable(np.array([-4, -3])),
-                    pt.as_tensor_variable(np.array([-2, -1])),
-                ],
-                name="xs",
-            )
-            return xs
-
-        with Model() as model:
-            CustomDist("x", dist=dist)
-        assert_support_point_is_expected(model, np.array([-3, -2]))
-
-    def test_custom_dist_default_support_point_scan_recurring(self):
-        def scan_step(xtm1):
-            x = pm.Normal.dist(xtm1 + 1)
-            x_update = collect_default_updates([x])
-            return x, x_update
-
-        def dist(size):
-            xs, _ = scan(
-                fn=scan_step,
-                outputs_info=pt.as_tensor_variable(np.array([0])).astype(float),
-                n_steps=3,
-                name="xs",
-            )
-            return xs
-
-        with Model() as model:
-            CustomDist("x", dist=dist)
-        assert_support_point_is_expected(model, np.array([[1], [2], [3]]))
-
-    @pytest.mark.parametrize(
-        "left, right, size, expected",
-        [
-            (-1, 1, None, 0 + 5),
-            (-3, -1, None, -2 + 5),
-            (-3, 1, (3,), np.array([-1 + 5, -1 + 5, -1 + 5])),
-        ],
-    )
-    def test_custom_dist_default_support_point_nested(self, left, right, size, expected):
-        def dist_fn(left, right, size):
-            return pm.Truncated.dist(pm.Normal.dist(0, 1), left, right, size=size) + 5
-
-        with Model() as model:
-            CustomDist("x", left, right, size=size, dist=dist_fn)
-        assert_support_point_is_expected(model, expected)
-
-    def test_logcdf_inference(self):
-        def custom_dist(mu, sigma, size):
-            return pt.exp(pm.Normal.dist(mu, sigma, size=size))
-
-        mu = 1
-        sigma = 1.25
-        test_value = 0.9
-
-        custom_lognormal = CustomDist.dist(mu, sigma, dist=custom_dist)
-        ref_lognormal = LogNormal.dist(mu, sigma)
-
-        np.testing.assert_allclose(
-            pm.logcdf(custom_lognormal, test_value).eval(),
-            pm.logcdf(ref_lognormal, test_value).eval(),
-        )
-
-    def test_random_multiple_rngs(self):
-        def custom_dist(p, sigma, size):
-            idx = pm.Bernoulli.dist(p=p)
-            if rv_size_is_none(size):
-                size = pt.broadcast_shape(p, sigma)
-            comps = pm.Normal.dist([-sigma, sigma], 1e-1, size=(*size, 2)).T
-            return comps[idx]
-
-        customdist = CustomDist.dist(
-            0.5,
-            10.0,
-            dist=custom_dist,
-            size=(10,),
-        )
-
-        assert isinstance(customdist.owner.op, CustomSymbolicDistRV)
-
-        node = customdist.owner
-        assert len(node.inputs) == 5  # Size, 2 inputs and 2 RNGs
-        assert len(node.outputs) == 3  # RV and 2 updated RNGs
-        assert len(node.op.update(node)) == 2
-
-        draws = pm.draw(customdist, draws=2, random_seed=123)
-        assert np.unique(draws).size == 20
-
-    def test_custom_methods(self):
-        def custom_dist(mu, size):
-            return DiracDelta.dist(mu, size=size)
-
-        def custom_support_point(rv, size, mu):
-            return pt.full_like(rv, mu + 1)
-
-        def custom_logp(value, mu):
-            return pt.full_like(value, mu + 2)
-
-        def custom_logcdf(value, mu):
-            return pt.full_like(value, mu + 3)
-
-        customdist = CustomDist.dist(
-            [np.e, np.e],
-            dist=custom_dist,
-            support_point=custom_support_point,
-            logp=custom_logp,
-            logcdf=custom_logcdf,
-        )
-
-        assert isinstance(customdist.owner.op, CustomSymbolicDistRV)
-
-        np.testing.assert_allclose(draw(customdist), [np.e, np.e])
-        np.testing.assert_allclose(support_point(customdist).eval(), [np.e + 1, np.e + 1])
-        np.testing.assert_allclose(logp(customdist, [0, 0]).eval(), [np.e + 2, np.e + 2])
-        np.testing.assert_allclose(logcdf(customdist, [0, 0]).eval(), [np.e + 3, np.e + 3])
-
-    def test_change_size(self):
-        def custom_dist(mu, sigma, size):
-            return pt.exp(pm.Normal.dist(mu, sigma, size=size))
-
-        lognormal = CustomDist.dist(
-            0,
-            1,
-            dist=custom_dist,
-            size=(10,),
-        )
-        assert isinstance(lognormal.owner.op, CustomSymbolicDistRV)
-        assert tuple(lognormal.shape.eval()) == (10,)
-
-        new_lognormal = change_dist_size(lognormal, new_size=(2, 5))
-        assert isinstance(new_lognormal.owner.op, CustomSymbolicDistRV)
-        assert tuple(new_lognormal.shape.eval()) == (2, 5)
-
-        new_lognormal = change_dist_size(lognormal, new_size=(2, 5), expand=True)
-        assert isinstance(new_lognormal.owner.op, CustomSymbolicDistRV)
-        assert tuple(new_lognormal.shape.eval()) == (2, 5, 10)
-
-    def test_error_model_access(self):
-        def custom_dist(size):
-            return pm.Flat("Flat", size=size)
-
-        with pm.Model() as m:
-            with pytest.raises(
-                BlockModelAccessError,
-                match="Model variables cannot be created in the dist function",
-            ):
-                CustomDist("custom_dist", dist=custom_dist)
-
-    def test_api_change_error(self):
-        def old_random(size):
-            return pm.Flat.dist(size=size)
-
-        # Old API raises
-        with pytest.raises(TypeError, match="API change: function passed to `random` argument"):
-            pm.CustomDist.dist(random=old_random, class_name="custom_dist")
-
-        # New API is fine
-        pm.CustomDist.dist(dist=old_random, class_name="custom_dist")
-
-    def test_scan(self):
-        def trw(nu, sigma, steps, size):
-            if rv_size_is_none(size):
-                size = ()
-
-            def step(xtm1, nu, sigma):
-                x = pm.StudentT.dist(nu=nu, mu=xtm1, sigma=sigma, shape=size)
-                return x, collect_default_updates([x])
-
-            xs, _ = scan(
-                fn=step,
-                outputs_info=pt.zeros(size),
-                non_sequences=[nu, sigma],
-                n_steps=steps,
-            )
-
-            # Logprob inference cannot be derived yet https://github.com/pymc-devs/pymc/issues/6360
-            # xs = swapaxes(xs, 0, -1)
-
-            return xs
-
-        nu = 4
-        sigma = 0.7
-        steps = 99
-        batch_size = 3
-        x = CustomDist.dist(nu, sigma, steps, dist=trw, size=batch_size)
-
-        x_draw = pm.draw(x, random_seed=1)
-        assert x_draw.shape == (steps, batch_size)
-        np.testing.assert_allclose(pm.draw(x, random_seed=1), x_draw)
-        assert not np.any(pm.draw(x, random_seed=2) == x_draw)
-
-        ref_dist = pm.RandomWalk.dist(
-            init_dist=pm.Flat.dist(),
-            innovation_dist=pm.StudentT.dist(nu=nu, sigma=sigma),
-            steps=steps,
-            size=(batch_size,),
-        )
-        ref_val = pt.concatenate([np.zeros((1, batch_size)), x_draw]).T
-
-        np.testing.assert_allclose(
-            pm.logp(x, x_draw).eval().sum(0),
-            pm.logp(ref_dist, ref_val).eval(),
-        )
-
-    def test_inferred_logp_mixture(self):
-        import numpy as np
-
-        import pymc as pm
-
-        def shifted_normal(mu, sigma, size):
-            return mu + pm.Normal.dist(0, sigma, shape=size)
-
-        mus = [3.5, -4.3]
-        sds = [1.5, 2.3]
-        w = [0.3, 0.7]
-        with pm.Model() as m:
-            comp_dists = [
-                pm.DensityDist.dist(mus[0], sds[0], dist=shifted_normal),
-                pm.DensityDist.dist(mus[1], sds[1], dist=shifted_normal),
-            ]
-            pm.Mixture("mix", w=w, comp_dists=comp_dists)
-
-        test_value = 0.1
-        np.testing.assert_allclose(
-            m.compile_logp()({"mix": test_value}),
-            pm.logp(pm.NormalMixture.dist(w=w, mu=mus, sigma=sds), test_value).eval(),
-        )
-
-    def test_symbolic_dist(self):
-        # Test we can create a SymbolicDist inside a CustomDist
-        def dist(size):
-            return pm.Truncated.dist(pm.Beta.dist(1, 1, size=size), lower=0.1, upper=0.9)
-
-        assert pm.CustomDist.dist(dist=dist)
-
-    def test_nested_custom_dist(self):
-        """Test we can create CustomDist that creates another CustomDist"""
-
-        def dist(size=None):
-            def inner_dist(size=None):
-                return pm.Normal.dist(size=size)
-
-            inner_dist = pm.CustomDist.dist(dist=inner_dist, size=size)
-            return pt.exp(inner_dist)
-
-        rv = pm.CustomDist.dist(dist=dist)
-        np.testing.assert_allclose(
-            pm.logp(rv, 1.0).eval(),
-            pm.logp(pm.LogNormal.dist(), 1.0).eval(),
-        )
-
-    def test_signature(self):
-        def dist(p, size):
-            return -pm.Categorical.dist(p=p, size=size)
-
-        out = CustomDist.dist([0.25, 0.75], dist=dist, signature="(p)->()")
-        # Size and updates are added automatically to the signature
-        assert out.owner.op.extended_signature == "[size],(p),[rng]->(),[rng]"
-        assert out.owner.op.ndim_supp == 0
-        assert out.owner.op.ndims_params == [1]
-
-        # When recreated internally, the whole signature may already be known
-        out = CustomDist.dist([0.25, 0.75], dist=dist, signature="[size],(p),[rng]->(),[rng]")
-        assert out.owner.op.extended_signature == "[size],(p),[rng]->(),[rng]"
-        assert out.owner.op.ndim_supp == 0
-        assert out.owner.op.ndims_params == [1]
-
-        # A safe signature can be inferred from ndim_supp and ndims_params
-        out = CustomDist.dist([0.25, 0.75], dist=dist, ndim_supp=0, ndims_params=[1])
-        assert out.owner.op.extended_signature == "[size],(i00),[rng]->(),[rng]"
-        assert out.owner.op.ndim_supp == 0
-        assert out.owner.op.ndims_params == [1]
-
-        # Otherwise be default we assume everything is scalar, even though it's wrong in this case
-        out = CustomDist.dist([0.25, 0.75], dist=dist)
-        assert out.owner.op.extended_signature == "[size],(),[rng]->(),[rng]"
-        assert out.owner.op.ndim_supp == 0
-        assert out.owner.op.ndims_params == [0]
-
-
 class TestSymbolicRandomVariable:
     def test_inline(self):
         class TestSymbolicRV(SymbolicRandomVariable):