diff --git a/pytensor/link/jax/dispatch/scalar.py b/pytensor/link/jax/dispatch/scalar.py
index 45b3c34c73..fc73f7e1bc 100644
--- a/pytensor/link/jax/dispatch/scalar.py
+++ b/pytensor/link/jax/dispatch/scalar.py
@@ -21,11 +21,14 @@
     Sub,
 )
 from pytensor.scalar.math import (
+    BetaIncInv,
     Erf,
     Erfc,
     Erfcinv,
     Erfcx,
     Erfinv,
+    GammaIncCInv,
+    GammaIncInv,
     Iv,
     Ive,
     Log1mexp,
@@ -226,6 +229,20 @@ def second(x, y):
     return second
 
 
+@jax_funcify.register(GammaIncInv)
+def jax_funcify_GammaIncInv(op, **kwargs):
+    gammaincinv = try_import_tfp_jax_op(op, jax_op_name="igammainv")
+
+    return gammaincinv
+
+
+@jax_funcify.register(GammaIncCInv)
+def jax_funcify_GammaIncCInv(op, **kwargs):
+    gammainccinv = try_import_tfp_jax_op(op, jax_op_name="igammacinv")
+
+    return gammainccinv
+
+
 @jax_funcify.register(Erf)
 def jax_funcify_Erf(op, node, **kwargs):
     def erf(x):
@@ -250,6 +267,7 @@ def erfinv(x):
     return erfinv
 
 
+@jax_funcify.register(BetaIncInv)
 @jax_funcify.register(Erfcx)
 @jax_funcify.register(Erfcinv)
 def jax_funcify_from_tfp(op, **kwargs):
diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py
index 2ab97ff122..7eba128100 100644
--- a/pytensor/scalar/math.py
+++ b/pytensor/scalar/math.py
@@ -733,6 +733,64 @@ def __hash__(self):
 gammaincc = GammaIncC(upgrade_to_float, name="gammaincc")
 
 
+class GammaIncInv(BinaryScalarOp):
+    """
+    Inverse of the regularized lower incomplete gamma function.
+    """
+
+    nfunc_spec = ("scipy.special.gammaincinv", 2, 1)
+
+    @staticmethod
+    def st_impl(k, x):
+        return scipy.special.gammaincinv(k, x)
+
+    def impl(self, k, x):
+        return GammaIncInv.st_impl(k, x)
+
+    def grad(self, inputs, grads):
+        (k, x) = inputs
+        (gz,) = grads
+        return [
+            grad_not_implemented(self, 0, k),
+            gz * exp(gammaincinv(k, x)) * gamma(k) * (gammaincinv(k, x) ** (1 - k)),
+        ]
+
+    def c_code(self, *args, **kwargs):
+        raise NotImplementedError()
+
+
+gammaincinv = GammaIncInv(upgrade_to_float, name="gammaincinv")
+
+
+class GammaIncCInv(BinaryScalarOp):
+    """
+    Inverse of the regularized upper incomplete gamma function.
+    """
+
+    nfunc_spec = ("scipy.special.gammainccinv", 2, 1)
+
+    @staticmethod
+    def st_impl(k, x):
+        return scipy.special.gammainccinv(k, x)
+
+    def impl(self, k, x):
+        return GammaIncCInv.st_impl(k, x)
+
+    def grad(self, inputs, grads):
+        (k, x) = inputs
+        (gz,) = grads
+        return [
+            grad_not_implemented(self, 0, k),
+            gz * -exp(gammainccinv(k, x)) * gamma(k) * (gammainccinv(k, x) ** (1 - k)),
+        ]
+
+    def c_code(self, *args, **kwargs):
+        raise NotImplementedError()
+
+
+gammainccinv = GammaIncCInv(upgrade_to_float, name="gammainccinv")
+
+
 def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name, loop_op=ScalarLoop):
     init = [as_scalar(x) if x is not None else None for x in init]
     constant = [as_scalar(x) for x in constant]
@@ -1648,6 +1706,43 @@ def inner_loop(
     return grad
 
 
+class BetaIncInv(ScalarOp):
+    """
+    Inverse of the regularized incomplete beta function.
+    """
+
+    nfunc_spec = ("scipy.special.betaincinv", 3, 1)
+
+    def impl(self, a, b, x):
+        return scipy.special.betaincinv(a, b, x)
+
+    def grad(self, inputs, grads):
+        (a, b, x) = inputs
+        (gz,) = grads
+        return [
+            grad_not_implemented(self, 0, a),
+            grad_not_implemented(self, 0, b),
+            gz
+            * exp(betaln(a, b))
+            * ((1 - betaincinv(a, b, x)) ** (1 - b))
+            * (betaincinv(a, b, x) ** (1 - a)),
+        ]
+
+    def c_code(self, *args, **kwargs):
+        raise NotImplementedError()
+
+
+betaincinv = BetaIncInv(upgrade_to_float_no_complex, name="betaincinv")
+
+
+def betaln(a, b):
+    """
+    Log beta function, computed from the log gamma function.
+    """
+
+    return gammaln(a) + gammaln(b) - gammaln(a + b)
+
+
 class Hyp2F1(ScalarOp):
     """
     Gaussian hypergeometric function ``2F1(a, b; c; z)``.
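A review note, not part of the patch: each x-gradient above follows from the inverse function theorem. Writing t = gammaincinv(k, x), so that P(k, t) = x for the regularized lower incomplete gamma P, we get dt/dx = 1 / P'(k, t) = gamma(k) * exp(t) * t**(1 - k), which is exactly the expression in GammaIncInv.grad; the upper-tail version picks up a minus sign because Q = 1 - P. A SciPy-only sanity check of this, as a sketch with arbitrary test values:

    import numpy as np
    from scipy import special

    k, x, eps = 3.5, 0.4, 1e-6

    # Lower incomplete: analytic dt/dx vs. central finite differences.
    t = special.gammaincinv(k, x)
    analytic = special.gamma(k) * np.exp(t) * t ** (1 - k)
    numeric = (special.gammaincinv(k, x + eps) - special.gammaincinv(k, x - eps)) / (2 * eps)
    np.testing.assert_allclose(analytic, numeric, rtol=1e-5)

    # Upper incomplete: same magnitude, opposite sign (Q = 1 - P),
    # matching the leading minus in GammaIncCInv.grad.
    t_c = special.gammainccinv(k, x)
    analytic_c = -special.gamma(k) * np.exp(t_c) * t_c ** (1 - k)
    numeric_c = (special.gammainccinv(k, x + eps) - special.gammainccinv(k, x - eps)) / (2 * eps)
    np.testing.assert_allclose(analytic_c, numeric_c, rtol=1e-5)

The BetaIncInv gradient is the same argument applied to I_t(a, b) = x, whose density is t**(a-1) * (1-t)**(b-1) / B(a, b); inverting it gives the exp(betaln(a, b)) * (1-t)**(1-b) * t**(1-a) factor used in the grad method.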
+ """ + + return gammaln(a) + gammaln(b) - gammaln(a + b) + + class Hyp2F1(ScalarOp): """ Gaussian hypergeometric function ``2F1(a, b; c; z)``. diff --git a/pytensor/tensor/inplace.py b/pytensor/tensor/inplace.py index 5f78f54615..73b3942327 100644 --- a/pytensor/tensor/inplace.py +++ b/pytensor/tensor/inplace.py @@ -283,6 +283,16 @@ def gammal_inplace(k, x): """lower incomplete gamma function""" +@scalar_elemwise +def gammaincinv_inplace(k, x): + """Inverse to the regularized lower incomplete gamma function""" + + +@scalar_elemwise +def gammainccinv_inplace(k, x): + """Inverse of the regularized upper incomplete gamma function""" + + @scalar_elemwise def j0_inplace(x): """Bessel function of the first kind of order 0.""" @@ -338,6 +348,11 @@ def betainc_inplace(a, b, x): """Regularized incomplete beta function""" +@scalar_elemwise +def betaincinv_inplace(a, b, x): + """Inverse of the regularized incomplete beta function""" + + @scalar_elemwise def second_inplace(a): """Fill `a` with `b`""" diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index efc7c20c45..45c926f501 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -1385,6 +1385,16 @@ def gammal(k, x): """Lower incomplete gamma function.""" +@scalar_elemwise +def gammaincinv(k, x): + """Inverse to the regularized lower incomplete gamma function""" + + +@scalar_elemwise +def gammainccinv(k, x): + """Inverse of the regularized upper incomplete gamma function""" + + @scalar_elemwise def hyp2f1(a, b, c, z): """Gaussian hypergeometric function.""" @@ -1451,6 +1461,11 @@ def betainc(a, b, x): """Regularized incomplete beta function""" +@scalar_elemwise +def betaincinv(a, b, x): + """Inverse of the regularized incomplete beta function""" + + @scalar_elemwise def real(z): """Return real component of complex-valued tensor `z`.""" @@ -3044,6 +3059,8 @@ def vectorize_node_to_matmul(op, node, batched_x, batched_y): "gammaincc", "gammau", "gammal", + "gammaincinv", + "gammainccinv", "j0", "j1", "jv", @@ -3057,6 +3074,7 @@ def vectorize_node_to_matmul(op, node, batched_x, batched_y): "log1pexp", "log1mexp", "betainc", + "betaincinv", "real", "imag", "angle", diff --git a/pytensor/tensor/special.py b/pytensor/tensor/special.py index c57e061602..7b5e52d637 100644 --- a/pytensor/tensor/special.py +++ b/pytensor/tensor/special.py @@ -6,7 +6,7 @@ from pytensor.graph.basic import Apply from pytensor.link.c.op import COp from pytensor.tensor.basic import as_tensor_variable -from pytensor.tensor.math import gamma, neg, sum +from pytensor.tensor.math import gamma, gammaln, neg, sum class SoftmaxGrad(COp): @@ -752,9 +752,27 @@ def factorial(n): return gamma(n + 1) +def beta(a, b): + """ + Beta function. + + """ + return (gamma(a) * gamma(b)) / gamma(a + b) + + +def betaln(a, b): + """ + Log beta function. 
+ + """ + return gammaln(a) + gammaln(b) - gammaln(a + b) + + __all__ = [ "softmax", "log_softmax", "poch", "factorial", + "beta", + "betaln", ] diff --git a/tests/link/jax/test_scalar.py b/tests/link/jax/test_scalar.py index 68f5a0bd6c..0469301791 100644 --- a/tests/link/jax/test_scalar.py +++ b/tests/link/jax/test_scalar.py @@ -11,12 +11,15 @@ from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.math import all as pt_all from pytensor.tensor.math import ( + betaincinv, cosh, erf, erfc, erfcinv, erfcx, erfinv, + gammainccinv, + gammaincinv, iv, log, log1mexp, @@ -165,6 +168,38 @@ def test_tfp_ops(op, test_values): compare_jax_and_py(fg, test_values) +def test_betaincinv(): + a = vector("a", dtype="float64") + b = vector("b", dtype="float64") + x = vector("x", dtype="float64") + out = betaincinv(a, b, x) + fg = FunctionGraph([a, b, x], [out]) + compare_jax_and_py( + fg, + [ + np.array([5.5, 7.0]), + np.array([5.5, 7.0]), + np.array([0.25, 0.7]), + ], + ) + + +def test_gammaincinv(): + k = vector("k", dtype="float64") + x = vector("x", dtype="float64") + out = gammaincinv(k, x) + fg = FunctionGraph([k, x], [out]) + compare_jax_and_py(fg, [np.array([5.5, 7.0]), np.array([0.25, 0.7])]) + + +def test_gammainccinv(): + k = vector("k", dtype="float64") + x = vector("x", dtype="float64") + out = gammainccinv(k, x) + fg = FunctionGraph([k, x], [out]) + compare_jax_and_py(fg, [np.array([5.5, 7.0]), np.array([0.25, 0.7])]) + + def test_psi(): x = scalar("x") out = psi(x) diff --git a/tests/tensor/test_math_scipy.py b/tests/tensor/test_math_scipy.py index d98daccf1d..3178d53b4e 100644 --- a/tests/tensor/test_math_scipy.py +++ b/tests/tensor/test_math_scipy.py @@ -69,6 +69,8 @@ def scipy_special_gammal(k, x): expected_gammaincc = scipy.special.gammaincc expected_gammau = scipy_special_gammau expected_gammal = scipy_special_gammal +expected_gammaincinv = scipy.special.gammaincinv +expected_gammainccinv = scipy.special.gammainccinv expected_j0 = scipy.special.j0 expected_j1 = scipy.special.j1 expected_jv = scipy.special.jv @@ -79,6 +81,7 @@ def scipy_special_gammal(k, x): expected_erfcx = scipy.special.erfcx expected_sigmoid = scipy.special.expit expected_hyp2f1 = scipy.special.hyp2f1 +expected_betaincinv = scipy.special.betaincinv TestErfBroadcast = makeBroadcastTester( op=pt.erf, @@ -484,6 +487,49 @@ def test_gammaincc_ddk_performance(benchmark): inplace=True, ) +rng = np.random.default_rng(seed=utt.fetch_seed()) +_good_broadcast_binary_gamma = dict( + normal=( + random_ranged(0, 100, (2, 3), rng=rng), + random_ranged(0, 1, (2, 3), rng=rng), + ), + empty=(np.asarray([], dtype=config.floatX), np.asarray([], dtype=config.floatX)), +) + +TestGammaIncInvBroadcast = makeBroadcastTester( + op=pt.gammaincinv, + expected=expected_gammaincinv, + good=_good_broadcast_binary_gamma, + eps=2e-8, + mode=mode_no_scipy, +) + +TestGammaIncInvInplaceBroadcast = makeBroadcastTester( + op=inplace.gammaincinv_inplace, + expected=expected_gammaincinv, + good=_good_broadcast_binary_gamma, + eps=2e-8, + mode=mode_no_scipy, + inplace=True, +) + +TestGammaInccInvBroadcast = makeBroadcastTester( + op=pt.gammainccinv, + expected=expected_gammainccinv, + good=_good_broadcast_binary_gamma, + eps=2e-8, + mode=mode_no_scipy, +) + +TestGammaInccInvInplaceBroadcast = makeBroadcastTester( + op=inplace.gammainccinv_inplace, + expected=expected_gammainccinv, + good=_good_broadcast_binary_gamma, + eps=2e-8, + mode=mode_no_scipy, + inplace=True, +) + rng = np.random.default_rng(seed=utt.fetch_seed()) 
diff --git a/tests/link/jax/test_scalar.py b/tests/link/jax/test_scalar.py
index 68f5a0bd6c..0469301791 100644
--- a/tests/link/jax/test_scalar.py
+++ b/tests/link/jax/test_scalar.py
@@ -11,12 +11,15 @@
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.math import all as pt_all
 from pytensor.tensor.math import (
+    betaincinv,
     cosh,
     erf,
     erfc,
     erfcinv,
     erfcx,
     erfinv,
+    gammainccinv,
+    gammaincinv,
     iv,
     log,
     log1mexp,
@@ -165,6 +168,38 @@ def test_tfp_ops(op, test_values):
     compare_jax_and_py(fg, test_values)
 
 
+def test_betaincinv():
+    a = vector("a", dtype="float64")
+    b = vector("b", dtype="float64")
+    x = vector("x", dtype="float64")
+    out = betaincinv(a, b, x)
+    fg = FunctionGraph([a, b, x], [out])
+    compare_jax_and_py(
+        fg,
+        [
+            np.array([5.5, 7.0]),
+            np.array([5.5, 7.0]),
+            np.array([0.25, 0.7]),
+        ],
+    )
+
+
+def test_gammaincinv():
+    k = vector("k", dtype="float64")
+    x = vector("x", dtype="float64")
+    out = gammaincinv(k, x)
+    fg = FunctionGraph([k, x], [out])
+    compare_jax_and_py(fg, [np.array([5.5, 7.0]), np.array([0.25, 0.7])])
+
+
+def test_gammainccinv():
+    k = vector("k", dtype="float64")
+    x = vector("x", dtype="float64")
+    out = gammainccinv(k, x)
+    fg = FunctionGraph([k, x], [out])
+    compare_jax_and_py(fg, [np.array([5.5, 7.0]), np.array([0.25, 0.7])])
+
+
 def test_psi():
     x = scalar("x")
     out = psi(x)
diff --git a/tests/tensor/test_math_scipy.py b/tests/tensor/test_math_scipy.py
index d98daccf1d..3178d53b4e 100644
--- a/tests/tensor/test_math_scipy.py
+++ b/tests/tensor/test_math_scipy.py
@@ -69,6 +69,8 @@ def scipy_special_gammal(k, x):
 expected_gammaincc = scipy.special.gammaincc
 expected_gammau = scipy_special_gammau
 expected_gammal = scipy_special_gammal
+expected_gammaincinv = scipy.special.gammaincinv
+expected_gammainccinv = scipy.special.gammainccinv
 expected_j0 = scipy.special.j0
 expected_j1 = scipy.special.j1
 expected_jv = scipy.special.jv
@@ -79,6 +81,7 @@ def scipy_special_gammal(k, x):
 expected_erfcx = scipy.special.erfcx
 expected_sigmoid = scipy.special.expit
 expected_hyp2f1 = scipy.special.hyp2f1
+expected_betaincinv = scipy.special.betaincinv
 
 TestErfBroadcast = makeBroadcastTester(
     op=pt.erf,
@@ -484,6 +487,49 @@ def test_gammaincc_ddk_performance(benchmark):
     inplace=True,
 )
 
+rng = np.random.default_rng(seed=utt.fetch_seed())
+_good_broadcast_binary_gamma = dict(
+    normal=(
+        random_ranged(0, 100, (2, 3), rng=rng),
+        random_ranged(0, 1, (2, 3), rng=rng),
+    ),
+    empty=(np.asarray([], dtype=config.floatX), np.asarray([], dtype=config.floatX)),
+)
+
+TestGammaIncInvBroadcast = makeBroadcastTester(
+    op=pt.gammaincinv,
+    expected=expected_gammaincinv,
+    good=_good_broadcast_binary_gamma,
+    eps=2e-8,
+    mode=mode_no_scipy,
+)
+
+TestGammaIncInvInplaceBroadcast = makeBroadcastTester(
+    op=inplace.gammaincinv_inplace,
+    expected=expected_gammaincinv,
+    good=_good_broadcast_binary_gamma,
+    eps=2e-8,
+    mode=mode_no_scipy,
+    inplace=True,
+)
+
+TestGammaInccInvBroadcast = makeBroadcastTester(
+    op=pt.gammainccinv,
+    expected=expected_gammainccinv,
+    good=_good_broadcast_binary_gamma,
+    eps=2e-8,
+    mode=mode_no_scipy,
+)
+
+TestGammaInccInvInplaceBroadcast = makeBroadcastTester(
+    op=inplace.gammainccinv_inplace,
+    expected=expected_gammainccinv,
+    good=_good_broadcast_binary_gamma,
+    eps=2e-8,
+    mode=mode_no_scipy,
+    inplace=True,
+)
+
 rng = np.random.default_rng(seed=utt.fetch_seed())
 _good_broadcast_unary_bessel = dict(
     normal=(random_ranged(-10, 10, (2, 3), rng=rng),),
@@ -880,6 +926,27 @@ def test_beta_inc_stan_grad_combined(self):
         )
 
 
+_good_broadcast_ternary_betaincinv = dict(
+    normal=(
+        random_ranged(0, 1000, (2, 3)),
+        random_ranged(0, 1000, (2, 3)),
+        random_ranged(0, 1, (2, 3)),
+    ),
+)
+
+TestBetaincinvBroadcast = makeBroadcastTester(
+    op=pt.betaincinv,
+    expected=scipy.special.betaincinv,
+    good=_good_broadcast_ternary_betaincinv,
+)
+
+TestBetaincinvInplaceBroadcast = makeBroadcastTester(
+    op=inplace.betaincinv_inplace,
+    expected=scipy.special.betaincinv,
+    good=_good_broadcast_ternary_betaincinv,
+    inplace=True,
+)
+
 _good_broadcast_quaternary_hyp2f1 = dict(
     normal=(
         random_ranged(0, 20, (2, 3)),
diff --git a/tests/tensor/test_special.py b/tests/tensor/test_special.py
index 17a9c05eff..a7448f1d86 100644
--- a/tests/tensor/test_special.py
+++ b/tests/tensor/test_special.py
@@ -1,5 +1,6 @@
 import numpy as np
 import pytest
+from scipy.special import beta as scipy_beta
 from scipy.special import factorial as scipy_factorial
 from scipy.special import log_softmax as scipy_log_softmax
 from scipy.special import poch as scipy_poch
@@ -11,6 +12,8 @@
     LogSoftmax,
     Softmax,
     SoftmaxGrad,
+    beta,
+    betaln,
     factorial,
     log_softmax,
     poch,
@@ -171,3 +174,29 @@ def test_factorial(n):
     np.testing.assert_allclose(
         actual, expected, rtol=1e-7 if config.floatX == "float64" else 1e-5
     )
+
+
+def test_beta():
+    _a, _b = vectors("a", "b")
+    actual_fn = function([_a, _b], beta(_a, _b))
+
+    a = random_ranged(0, 5, (2,))
+    b = random_ranged(0, 5, (2,))
+    actual = actual_fn(a, b)
+    expected = scipy_beta(a, b)
+    np.testing.assert_allclose(
+        actual, expected, rtol=1e-7 if config.floatX == "float64" else 1e-5
+    )
+
+
+def test_betaln():
+    _a, _b = vectors("a", "b")
+    actual_fn = function([_a, _b], betaln(_a, _b))
+
+    a = random_ranged(0, 5, (2,))
+    b = random_ranged(0, 5, (2,))
+    actual = np.exp(actual_fn(a, b))
+    expected = scipy_beta(a, b)
+    np.testing.assert_allclose(
+        actual, expected, rtol=1e-7 if config.floatX == "float64" else 1e-5
+    )
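On the JAX side, the new ops are routed through `try_import_tfp_jax_op` to the TensorFlow Probability substrates (`igammainv` and `igammacinv` explicitly; `BetaIncInv` via the generic `jax_funcify_from_tfp` path), so exercising them under the JAX backend requires `tensorflow-probability` in addition to `jax`. A hedged end-to-end sketch mirroring `test_gammaincinv` above, assuming both dependencies are installed:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    k = pt.vector("k", dtype="float64")
    x = pt.vector("x", dtype="float64")

    # Compiling with mode="JAX" dispatches GammaIncInv to tfp's igammainv.
    fn = pytensor.function([k, x], pt.gammaincinv(k, x), mode="JAX")
    print(fn(np.array([5.5, 7.0]), np.array([0.25, 0.7])))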