diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6cdfa3fb2c..18157e434a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -145,7 +145,7 @@ jobs: # PyTensor next, pip installs a lower version of numpy via the PyPI. if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION == "3.9" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numpy<1.23" "numba>=0.57" numba-scipy; fi if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION != "3.9" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57" numba-scipy; fi - if [[ $INSTALL_JAX == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro; fi + if [[ $INSTALL_JAX == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi pip install -e ./ mamba list && pip freeze python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))' diff --git a/pytensor/link/jax/dispatch/scalar.py b/pytensor/link/jax/dispatch/scalar.py index bb63bf3135..6f63474de4 100644 --- a/pytensor/link/jax/dispatch/scalar.py +++ b/pytensor/link/jax/dispatch/scalar.py @@ -1,4 +1,6 @@ import functools +import typing +from typing import Callable, Optional import jax import jax.numpy as jnp @@ -18,7 +20,21 @@ Second, Sub, ) -from pytensor.scalar.math import Erf, Erfc, Erfinv, Log1mexp, Psi +from pytensor.scalar.math import Erf, Erfc, Erfcinv, Erfcx, Erfinv, Iv, Log1mexp, Psi + + +def try_import_tfp_jax_op(op: ScalarOp, jax_op_name: Optional[str] = None) -> Callable: + try: + import tensorflow_probability.substrates.jax.math as tfp_jax_math + except ModuleNotFoundError: + raise NotImplementedError( + f"No JAX implementation for Op {op.name}. " + "Implementation is available if TensorFlow Probability is installed" + ) + + if jax_op_name is None: + jax_op_name = op.name + return typing.cast(Callable, getattr(tfp_jax_math, jax_op_name)) def check_if_inputs_scalars(node): @@ -211,6 +227,24 @@ def erfinv(x): return erfinv +@jax_funcify.register(Erfcx) +@jax_funcify.register(Erfcinv) +def jax_funcify_from_tfp(op, **kwargs): + tfp_jax_op = try_import_tfp_jax_op(op) + + return tfp_jax_op + + +@jax_funcify.register(Iv) +def jax_funcify_Iv(op, **kwargs): + ive = try_import_tfp_jax_op(op, jax_op_name="bessel_ive") + + def iv(v, x): + return ive(v, x) / jnp.exp(-jnp.abs(jnp.real(x))) + + return iv + + @jax_funcify.register(Log1mexp) def jax_funcify_Log1mexp(op, node, **kwargs): def log1mexp(x): diff --git a/tests/link/jax/test_scalar.py b/tests/link/jax/test_scalar.py index 8d428f450f..9691dd535c 100644 --- a/tests/link/jax/test_scalar.py +++ b/tests/link/jax/test_scalar.py @@ -7,13 +7,17 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import get_test_value from pytensor.scalar.basic import Composite +from pytensor.tensor import as_tensor from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.math import all as at_all from pytensor.tensor.math import ( cosh, erf, erfc, + erfcinv, + erfcx, erfinv, + iv, log, log1mexp, psi, @@ -28,6 +32,14 @@ from pytensor.link.jax.dispatch import jax_funcify +try: + pass + + TFP_INSTALLED = True +except ModuleNotFoundError: + TFP_INSTALLED = False + + def test_second(): a0 = scalar("a0") b = scalar("b") @@ -134,6 +146,23 @@ def test_erfinv(): compare_jax_and_py(fg, [0.95]) +@pytest.mark.parametrize( + "op, test_values", + [ + (erfcx, (0.7,)), + (erfcinv, (0.7,)), + (iv, (0.3, 0.7)), + ], +) +@pytest.mark.skipif(not TFP_INSTALLED, reason="Test requires tensorflow-probability") +def test_tfp_ops(op, test_values): + inputs = [as_tensor(test_value).type() for test_value in test_values] + output = op(*inputs) + + fg = FunctionGraph(inputs, [output]) + compare_jax_and_py(fg, test_values) + + def test_psi(): x = scalar("x") out = psi(x)