From be3e038ec899a6f0f6ba15c0c180c6d3ebfb8d0b Mon Sep 17 00:00:00 2001 From: amyoshino Date: Wed, 15 Nov 2023 21:30:02 -0500 Subject: [PATCH 01/15] add betaincinv and gammaincinv functions --- pytensor/scalar/math.py | 37 +++++++++++++++++++++++++++++++++++++ tests/scalar/test_math.py | 15 +++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index f87f42066c..d67c148881 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -683,6 +683,26 @@ def __hash__(self): gammaincc = GammaIncC(upgrade_to_float, name="gammaincc") +class GammaIncInv(BinaryScalarOp): + """ + Inverse to the regularized lower incomplete gamma function. + """ + + nfunc_spec = ("scipy.special.gammaincinv", 1, 1) + + @staticmethod + def st_impl(k, x): + return scipy.special.gammaincinv(k, x) + + def impl(self, k, x): + return GammaIncInv.st_impl(k, x) + + def c_code(self, *args, **kwargs): + raise NotImplementedError() + + +gammaincinv = GammaIncInv(upgrade_to_float, name="gammaincinv") + def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name, loop_op=ScalarLoop): init = [as_scalar(x) if x is not None else None for x in init] @@ -1567,6 +1587,23 @@ def inner_loop( ) return grad +class BetaIncInv(ScalarOp): + """ + Inverse of the regularized incomplete beta function. + """ + + nin = 3 + nfunc_spec = ("scipy.special.betaincinv", 1, 1) + + def impl(self, a, b, x): + return scipy.special.betaincinv(a, b, x) + + def c_code(self, *args, **kwargs): + raise NotImplementedError() + + +betaincinv = BetaIncInv(upgrade_to_float_no_complex, name="betaincinv") + class Hyp2F1(ScalarOp): """ diff --git a/tests/scalar/test_math.py b/tests/scalar/test_math.py index 1998ed5fa5..0aaa9ef73b 100644 --- a/tests/scalar/test_math.py +++ b/tests/scalar/test_math.py @@ -14,8 +14,10 @@ from pytensor.scalar.math import ( betainc, betainc_grad, + betaincinv, gammainc, gammaincc, + gammaincinv, gammal, gammau, hyp2f1, @@ -58,6 +60,13 @@ def test_gammaincc_nan_c(): assert np.isnan(test_func(1, -1)) assert np.isnan(test_func(-1, -1)) +def test_gammaincinv_python(): + x1 = at.dscalar() + x2 = at.dscalar() + y = gammaincinv(x1, x2) + test_func = function([x1, x2], y, mode=Mode("py")) + assert np.isclose(test_func(1, 0.2), sp.gammaincinv(1, 0.2)) + def test_gammal_nan_c(): x1 = at.dscalar() @@ -96,6 +105,12 @@ def test_betainc_derivative_nan(): assert np.isnan(test_func(1, -1, 1)) assert np.isnan(test_func(1, 1, -1)) +def test_betaincinv(): + a, b, x = at.scalars("a", "b", "x") + res = betaincinv(a, b, x) + test_func = function([a, b, x], res, mode=Mode("py")) + assert np.isclose(test_func(15, 10, 0.7), sp.betaincinv(15, 10, 0.7)) + @pytest.mark.parametrize( "op, scalar_loop_grads", From ff9049dc7d935f230d658fb432f6f8c988b55826 Mon Sep 17 00:00:00 2001 From: amyoshino Date: Wed, 15 Nov 2023 21:32:18 -0500 Subject: [PATCH 02/15] add betaincinv and gammaincinv functions --- pytensor/scalar/math.py | 2 ++ tests/scalar/test_math.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index d67c148881..6406dba670 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -683,6 +683,7 @@ def __hash__(self): gammaincc = GammaIncC(upgrade_to_float, name="gammaincc") + class GammaIncInv(BinaryScalarOp): """ Inverse to the regularized lower incomplete gamma function. @@ -1587,6 +1588,7 @@ def inner_loop( ) return grad + class BetaIncInv(ScalarOp): """ Inverse of the regularized incomplete beta function. diff --git a/tests/scalar/test_math.py b/tests/scalar/test_math.py index 0aaa9ef73b..62006dcbbc 100644 --- a/tests/scalar/test_math.py +++ b/tests/scalar/test_math.py @@ -60,6 +60,7 @@ def test_gammaincc_nan_c(): assert np.isnan(test_func(1, -1)) assert np.isnan(test_func(-1, -1)) + def test_gammaincinv_python(): x1 = at.dscalar() x2 = at.dscalar() @@ -105,6 +106,7 @@ def test_betainc_derivative_nan(): assert np.isnan(test_func(1, -1, 1)) assert np.isnan(test_func(1, 1, -1)) + def test_betaincinv(): a, b, x = at.scalars("a", "b", "x") res = betaincinv(a, b, x) From 753d1500b120d2d4b0f9f86af17253c9c2c6f41f Mon Sep 17 00:00:00 2001 From: amyoshino Date: Wed, 15 Nov 2023 22:21:32 -0500 Subject: [PATCH 03/15] removing nin variable --- pytensor/scalar/math.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index 6406dba670..a9f3e9d1c5 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -1594,7 +1594,6 @@ class BetaIncInv(ScalarOp): Inverse of the regularized incomplete beta function. """ - nin = 3 nfunc_spec = ("scipy.special.betaincinv", 1, 1) def impl(self, a, b, x): From c3b9f14063abe97ee9aa4ef137f9fc5421686047 Mon Sep 17 00:00:00 2001 From: "Adriano M. Yoshino" Date: Thu, 16 Nov 2023 09:11:10 -0500 Subject: [PATCH 04/15] Update pytensor/scalar/math.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pytensor/scalar/math.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index a9f3e9d1c5..a4098513f8 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -1594,7 +1594,7 @@ class BetaIncInv(ScalarOp): Inverse of the regularized incomplete beta function. """ - nfunc_spec = ("scipy.special.betaincinv", 1, 1) + nfunc_spec = ("scipy.special.betaincinv", 2, 1) def impl(self, a, b, x): return scipy.special.betaincinv(a, b, x) From dd1b1e9f8ae5eac2705fa52bfc062f83301bd1eb Mon Sep 17 00:00:00 2001 From: amyoshino Date: Sun, 26 Nov 2023 12:02:08 -0300 Subject: [PATCH 05/15] add first derivative and tests --- pytensor/scalar/math.py | 55 +++++++++++++++++++++++++++++++++ pytensor/tensor/inplace.py | 15 +++++++++ pytensor/tensor/math.py | 18 +++++++++++ tests/scalar/test_math.py | 17 ---------- tests/tensor/test_math_scipy.py | 50 ++++++++++++++++++++++++++++++ 5 files changed, 138 insertions(+), 17 deletions(-) diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index a4098513f8..c784950f39 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -698,6 +698,17 @@ def st_impl(k, x): def impl(self, k, x): return GammaIncInv.st_impl(k, x) + def grad(self, inputs, grads): + (k, x) = inputs + (gz,) = grads + return [ + grad_not_implemented(self, 0, k), + gz + * exp(scipy.special.gammaincinv(k, x)) + * scipy.special.gamma(k) + * (scipy.special.gammaincinv(k, x) ** (1 - k)), + ] + def c_code(self, *args, **kwargs): raise NotImplementedError() @@ -705,6 +716,38 @@ def c_code(self, *args, **kwargs): gammaincinv = GammaIncInv(upgrade_to_float, name="gammaincinv") +class GammaIncCInv(BinaryScalarOp): + """ + Inverse to the regularized upper incomplete gamma function. + """ + + nfunc_spec = ("scipy.special.gammaincinv", 1, 1) + + @staticmethod + def st_impl(k, x): + return scipy.special.gammainccinv(k, x) + + def impl(self, k, x): + return GammaIncInv.st_impl(k, x) + + def grad(self, inputs, grads): + (k, x) = inputs + (gz,) = grads + return [ + grad_not_implemented(self, 0, k), + gz + * -exp(scipy.special.gammainccinv(k, x)) + * scipy.special.gamma(k) + * (scipy.special.gammainccinv(k, x) ** (1 - k)), + ] + + def c_code(self, *args, **kwargs): + raise NotImplementedError() + + +gammainccinv = GammaIncCInv(upgrade_to_float, name="gammainccinv") + + def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name, loop_op=ScalarLoop): init = [as_scalar(x) if x is not None else None for x in init] constant = [as_scalar(x) for x in constant] @@ -1599,6 +1642,18 @@ class BetaIncInv(ScalarOp): def impl(self, a, b, x): return scipy.special.betaincinv(a, b, x) + def grad(self, inputs, grads): + (a, b, x) = inputs + (gz,) = grads + return [ + grad_not_implemented(self, 0, a), + grad_not_implemented(self, 0, b), + gz + * scipy.special.beta(a, b) + * ((1 - scipy.special.betaincinv(a, b, x)) ** (1 - b)) + * (scipy.special.betaincinv(a, b, x) ** (1 - a)), + ] + def c_code(self, *args, **kwargs): raise NotImplementedError() diff --git a/pytensor/tensor/inplace.py b/pytensor/tensor/inplace.py index afb7f7ac7c..7addab69f6 100644 --- a/pytensor/tensor/inplace.py +++ b/pytensor/tensor/inplace.py @@ -283,6 +283,16 @@ def gammal_inplace(k, x): """lower incomplete gamma function""" +@scalar_elemwise +def gammaincinv_inplace(k, x): + """Inverse to the regularized lower incomplete gamma function""" + + +@scalar_elemwise +def gammainccinv_inplace(k, x): + """Inverse of the regularized upper incomplete gamma function""" + + @scalar_elemwise def j0_inplace(x): """Bessel function of the first kind of order 0.""" @@ -333,6 +343,11 @@ def betainc_inplace(a, b, x): """Regularized incomplete beta function""" +@scalar_elemwise +def betaincinv_inplace(a, b, x): + """Inverse of the regularized incomplete beta function""" + + @scalar_elemwise def second_inplace(a): """Fill `a` with `b`""" diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 56777eeb67..8a5ac8e517 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -1394,6 +1394,16 @@ def gammal(k, x): """Lower incomplete gamma function.""" +@scalar_elemwise +def gammaincinv(k, x): + """Inverse to the regularized lower incomplete gamma function""" + + +@scalar_elemwise +def gammainccinv(k, x): + """Inverse of the regularized upper incomplete gamma function""" + + @scalar_elemwise def hyp2f1(a, b, c, z): """Gaussian hypergeometric function.""" @@ -1455,6 +1465,11 @@ def betainc(a, b, x): """Regularized incomplete beta function""" +@scalar_elemwise +def betaincinv(a, b, x): + """Inverse of the regularized incomplete beta function""" + + @scalar_elemwise def real(z): """Return real component of complex-valued tensor `z`.""" @@ -3013,6 +3028,8 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None "gammaincc", "gammau", "gammal", + "gammaincinv", + "gammainccinv", "j0", "j1", "jv", @@ -3025,6 +3042,7 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None "log1pexp", "log1mexp", "betainc", + "betaincinv", "real", "imag", "angle", diff --git a/tests/scalar/test_math.py b/tests/scalar/test_math.py index 62006dcbbc..1998ed5fa5 100644 --- a/tests/scalar/test_math.py +++ b/tests/scalar/test_math.py @@ -14,10 +14,8 @@ from pytensor.scalar.math import ( betainc, betainc_grad, - betaincinv, gammainc, gammaincc, - gammaincinv, gammal, gammau, hyp2f1, @@ -61,14 +59,6 @@ def test_gammaincc_nan_c(): assert np.isnan(test_func(-1, -1)) -def test_gammaincinv_python(): - x1 = at.dscalar() - x2 = at.dscalar() - y = gammaincinv(x1, x2) - test_func = function([x1, x2], y, mode=Mode("py")) - assert np.isclose(test_func(1, 0.2), sp.gammaincinv(1, 0.2)) - - def test_gammal_nan_c(): x1 = at.dscalar() x2 = at.dscalar() @@ -107,13 +97,6 @@ def test_betainc_derivative_nan(): assert np.isnan(test_func(1, 1, -1)) -def test_betaincinv(): - a, b, x = at.scalars("a", "b", "x") - res = betaincinv(a, b, x) - test_func = function([a, b, x], res, mode=Mode("py")) - assert np.isclose(test_func(15, 10, 0.7), sp.betaincinv(15, 10, 0.7)) - - @pytest.mark.parametrize( "op, scalar_loop_grads", [ diff --git a/tests/tensor/test_math_scipy.py b/tests/tensor/test_math_scipy.py index 11e0e1730a..0e453a1a1e 100644 --- a/tests/tensor/test_math_scipy.py +++ b/tests/tensor/test_math_scipy.py @@ -69,6 +69,8 @@ def scipy_special_gammal(k, x): expected_gammaincc = scipy.special.gammaincc expected_gammau = scipy_special_gammau expected_gammal = scipy_special_gammal +expected_gammaincinv = scipy.special.gammaincinv +expected_gammainccinv = scipy.special.gammainccinv expected_j0 = scipy.special.j0 expected_j1 = scipy.special.j1 expected_jv = scipy.special.jv @@ -78,6 +80,7 @@ def scipy_special_gammal(k, x): expected_erfcx = scipy.special.erfcx expected_sigmoid = scipy.special.expit expected_hyp2f1 = scipy.special.hyp2f1 +expected_betaincinv = scipy.special.betaincinv TestErfBroadcast = makeBroadcastTester( op=at.erf, @@ -483,6 +486,53 @@ def test_gammaincc_ddk_performance(benchmark): inplace=True, ) +rng = np.random.default_rng(seed=utt.fetch_seed()) +_good_broadcast_binary_gamma = dict( + normal=( + random_ranged(1e-2, 1, (2, 3), rng=rng), + random_ranged(1e-2, 1, (2, 3), rng=rng), + ), + empty=(np.asarray([], dtype=config.floatX), np.asarray([], dtype=config.floatX)), +) + +_good_broadcast_binary_gamma_grad = dict(normal=(random_ranged(-10.0, 10.0, (2, 3)),)) + +TestGammaIncInvBroadcast = makeBroadcastTester( + op=at.gammaincinv, + expected=expected_gammaincinv, + good=_good_broadcast_binary_gamma, + grad=_good_broadcast_binary_gamma_grad, + eps=2e-8, + mode=mode_no_scipy, +) + +TestGammaIncInvInplaceBroadcast = makeBroadcastTester( + op=inplace.gammaincinv_inplace, + expected=expected_gammaincinv, + good=_good_broadcast_binary_gamma, + eps=2e-8, + mode=mode_no_scipy, + inplace=True, +) + +TestGammaInccInvBroadcast = makeBroadcastTester( + op=at.gammainccinv, + expected=expected_gammainccinv, + good=_good_broadcast_binary_gamma, + grad=_good_broadcast_binary_gamma_grad, + eps=2e-8, + mode=mode_no_scipy, +) + +TestGammaInccInvInplaceBroadcast = makeBroadcastTester( + op=inplace.gammainccinv_inplace, + expected=expected_gammainccinv, + good=_good_broadcast_binary_gamma, + eps=2e-8, + mode=mode_no_scipy, + inplace=True, +) + rng = np.random.default_rng(seed=utt.fetch_seed()) _good_broadcast_unary_bessel = dict( normal=(random_ranged(-10, 10, (2, 3), rng=rng),), From 7125c2e9c598914cfb77b6f1cc50cd1ccb513c21 Mon Sep 17 00:00:00 2001 From: amyoshino Date: Mon, 18 Dec 2023 09:55:33 -0300 Subject: [PATCH 06/15] add beta function, tests and jax ops --- pytensor/link/jax/dispatch/scalar.py | 18 +++++++ pytensor/scalar/math.py | 49 +++++++++++------ pytensor/tensor/inplace.py | 5 ++ pytensor/tensor/math.py | 6 +++ tests/link/jax/test_scalar.py | 6 +++ tests/tensor/test_math_scipy.py | 78 +++++++++++++++++++++++++--- 6 files changed, 140 insertions(+), 22 deletions(-) diff --git a/pytensor/link/jax/dispatch/scalar.py b/pytensor/link/jax/dispatch/scalar.py index 45b3c34c73..fc73f7e1bc 100644 --- a/pytensor/link/jax/dispatch/scalar.py +++ b/pytensor/link/jax/dispatch/scalar.py @@ -21,11 +21,14 @@ Sub, ) from pytensor.scalar.math import ( + BetaIncInv, Erf, Erfc, Erfcinv, Erfcx, Erfinv, + GammaIncCInv, + GammaIncInv, Iv, Ive, Log1mexp, @@ -226,6 +229,20 @@ def second(x, y): return second +@jax_funcify.register(GammaIncInv) +def jax_funcify_GammaIncInv(op, **kwargs): + gammaincinv = try_import_tfp_jax_op(op, jax_op_name="igammainv") + + return gammaincinv + + +@jax_funcify.register(GammaIncCInv) +def jax_funcify_GammaIncCInv(op, **kwargs): + gammainccinv = try_import_tfp_jax_op(op, jax_op_name="igammacinv") + + return gammainccinv + + @jax_funcify.register(Erf) def jax_funcify_Erf(op, node, **kwargs): def erf(x): @@ -250,6 +267,7 @@ def erfinv(x): return erfinv +@jax_funcify.register(BetaIncInv) @jax_funcify.register(Erfcx) @jax_funcify.register(Erfcinv) def jax_funcify_from_tfp(op, **kwargs): diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index 8a5db49ef2..b71b8b11f1 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -738,7 +738,7 @@ class GammaIncInv(BinaryScalarOp): Inverse to the regularized lower incomplete gamma function. """ - nfunc_spec = ("scipy.special.gammaincinv", 1, 1) + nfunc_spec = ("scipy.special.gammaincinv", 2, 1) @staticmethod def st_impl(k, x): @@ -752,10 +752,7 @@ def grad(self, inputs, grads): (gz,) = grads return [ grad_not_implemented(self, 0, k), - gz - * exp(scipy.special.gammaincinv(k, x)) - * scipy.special.gamma(k) - * (scipy.special.gammaincinv(k, x) ** (1 - k)), + gz * exp(gammaincinv(k, x)) * gamma(k) * (gammaincinv(k, x) ** (1 - k)), ] def c_code(self, *args, **kwargs): @@ -770,24 +767,21 @@ class GammaIncCInv(BinaryScalarOp): Inverse to the regularized upper incomplete gamma function. """ - nfunc_spec = ("scipy.special.gammaincinv", 1, 1) + nfunc_spec = ("scipy.special.gammainccinv", 2, 1) @staticmethod def st_impl(k, x): return scipy.special.gammainccinv(k, x) def impl(self, k, x): - return GammaIncInv.st_impl(k, x) + return GammaIncCInv.st_impl(k, x) def grad(self, inputs, grads): (k, x) = inputs (gz,) = grads return [ grad_not_implemented(self, 0, k), - gz - * -exp(scipy.special.gammainccinv(k, x)) - * scipy.special.gamma(k) - * (scipy.special.gammainccinv(k, x) ** (1 - k)), + gz * -exp(gammainccinv(k, x)) * gamma(k) * (gammainccinv(k, x) ** (1 - k)), ] def c_code(self, *args, **kwargs): @@ -1712,12 +1706,37 @@ def inner_loop( return grad +class Beta(BinaryScalarOp): + """ + Beta function. + """ + + nfunc_spec = ("scipy.special.beta", 2, 1) + + def impl(self, a, b): + return scipy.special.beta(a, b) + + def grad(self, inputs, grads): + (a, b) = inputs + (gz,) = grads + return [ + gz * beta(a, b) * (polygamma(0, a) - polygamma(0, a + b)), + gz * beta(a, b) * (polygamma(0, b) - polygamma(0, a + b)), + ] + + def c_code(self, *args, **kwargs): + raise NotImplementedError() + + +beta = Beta(upgrade_to_float_no_complex, name="beta") + + class BetaIncInv(ScalarOp): """ Inverse of the regularized incomplete beta function. """ - nfunc_spec = ("scipy.special.betaincinv", 2, 1) + nfunc_spec = ("scipy.special.betaincinv", 3, 1) def impl(self, a, b, x): return scipy.special.betaincinv(a, b, x) @@ -1729,9 +1748,9 @@ def grad(self, inputs, grads): grad_not_implemented(self, 0, a), grad_not_implemented(self, 0, b), gz - * scipy.special.beta(a, b) - * ((1 - scipy.special.betaincinv(a, b, x)) ** (1 - b)) - * (scipy.special.betaincinv(a, b, x) ** (1 - a)), + * beta(a, b) + * ((1 - betaincinv(a, b, x)) ** (1 - b)) + * (betaincinv(a, b, x) ** (1 - a)), ] def c_code(self, *args, **kwargs): diff --git a/pytensor/tensor/inplace.py b/pytensor/tensor/inplace.py index 73b3942327..a403de48ab 100644 --- a/pytensor/tensor/inplace.py +++ b/pytensor/tensor/inplace.py @@ -353,6 +353,11 @@ def betaincinv_inplace(a, b, x): """Inverse of the regularized incomplete beta function""" +@scalar_elemwise +def beta_inplace(a, b): + """Beta function""" + + @scalar_elemwise def second_inplace(a): """Fill `a` with `b`""" diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index d414a2e646..79202ed682 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -1481,6 +1481,11 @@ def betaincinv(a, b, x): """Inverse of the regularized incomplete beta function""" +@scalar_elemwise +def beta(a, b): + """Beta function""" + + @scalar_elemwise def real(z): """Return real component of complex-valued tensor `z`.""" @@ -3069,6 +3074,7 @@ def vectorize_node_to_matmul(op, node, batched_x, batched_y): "log1mexp", "betainc", "betaincinv", + "beta", "real", "imag", "angle", diff --git a/tests/link/jax/test_scalar.py b/tests/link/jax/test_scalar.py index 68f5a0bd6c..ff6ca697fb 100644 --- a/tests/link/jax/test_scalar.py +++ b/tests/link/jax/test_scalar.py @@ -11,12 +11,15 @@ from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.math import all as pt_all from pytensor.tensor.math import ( + betaincinv, cosh, erf, erfc, erfcinv, erfcx, erfinv, + gammainccinv, + gammaincinv, iv, log, log1mexp, @@ -151,8 +154,11 @@ def test_erfinv(): @pytest.mark.parametrize( "op, test_values", [ + (betaincinv, (3.0, 5.5, 0.7)), (erfcx, (0.7,)), (erfcinv, (0.7,)), + (gammaincinv, (5.5, 0.7)), + (gammainccinv, (5.5, 0.7)), (iv, (0.3, 0.7)), ], ) diff --git a/tests/tensor/test_math_scipy.py b/tests/tensor/test_math_scipy.py index b73bc1b380..5dc5cc0a6a 100644 --- a/tests/tensor/test_math_scipy.py +++ b/tests/tensor/test_math_scipy.py @@ -82,6 +82,7 @@ def scipy_special_gammal(k, x): expected_sigmoid = scipy.special.expit expected_hyp2f1 = scipy.special.hyp2f1 expected_betaincinv = scipy.special.betaincinv +expected_beta = scipy.special.beta TestErfBroadcast = makeBroadcastTester( op=pt.erf, @@ -490,19 +491,20 @@ def test_gammaincc_ddk_performance(benchmark): rng = np.random.default_rng(seed=utt.fetch_seed()) _good_broadcast_binary_gamma = dict( normal=( - random_ranged(1e-2, 1, (2, 3), rng=rng), - random_ranged(1e-2, 1, (2, 3), rng=rng), + random_ranged(0, 100, (2, 3), rng=rng), + random_ranged(0, 1, (2, 3), rng=rng), ), empty=(np.asarray([], dtype=config.floatX), np.asarray([], dtype=config.floatX)), ) -_good_broadcast_binary_gamma_grad = dict(normal=(random_ranged(-10.0, 10.0, (2, 3)),)) +_good_broadcast_binary_gamma_grad = dict( + normal=(random_ranged(0, 1000.0, (2, 3)), random_ranged(0.0, 1.0, (2, 3))) +) TestGammaIncInvBroadcast = makeBroadcastTester( - op=at.gammaincinv, + op=pt.gammaincinv, expected=expected_gammaincinv, good=_good_broadcast_binary_gamma, - grad=_good_broadcast_binary_gamma_grad, eps=2e-8, mode=mode_no_scipy, ) @@ -517,10 +519,9 @@ def test_gammaincc_ddk_performance(benchmark): ) TestGammaInccInvBroadcast = makeBroadcastTester( - op=at.gammainccinv, + op=pt.gammainccinv, expected=expected_gammainccinv, good=_good_broadcast_binary_gamma, - grad=_good_broadcast_binary_gamma_grad, eps=2e-8, mode=mode_no_scipy, ) @@ -930,6 +931,69 @@ def test_beta_inc_stan_grad_combined(self): ) +rng = np.random.default_rng(seed=utt.fetch_seed()) +_good_broadcast_binary_beta = dict( + normal=( + random_ranged(1e-2, 100, (2, 3), rng=rng), + random_ranged(1e-2, 100, (2, 3), rng=rng), + ), + integers=( + integers_ranged(1, 100, (2, 3), rng=rng), + integers_ranged(1, 100, (2, 3), rng=rng), + ), + uint8=( + integers_ranged(1, 100, (2, 3), rng=rng).astype("uint8"), + integers_ranged(1, 100, (2, 3), rng=rng).astype("uint8"), + ), + uint16=( + integers_ranged(1, 100, (2, 3), rng=rng).astype("uint16"), + integers_ranged(1, 100, (2, 3), rng=rng).astype("uint16"), + ), +) + +_grad_broadcast_binary_beta = dict( + normal=( + random_ranged(1e-2, 100, (2, 3), rng=rng), + random_ranged(1e-2, 100, (2, 3), rng=rng), + ) +) + +TestBetaBroadcast = makeBroadcastTester( + op=pt.beta, + expected=scipy.special.beta, + good=_good_broadcast_binary_beta, + grad=_grad_broadcast_binary_beta, +) + +TestBetaInplaceBroadcast = makeBroadcastTester( + op=inplace.beta_inplace, + expected=scipy.special.beta, + good=_good_broadcast_binary_beta, + grad=_grad_broadcast_binary_beta, + inplace=True, +) + +_good_broadcast_ternary_betaincinv = dict( + normal=( + random_ranged(0, 1000, (2, 3)), + random_ranged(0, 1000, (2, 3)), + random_ranged(0, 1, (2, 3)), + ), +) + +TestBetaincinvBroadcast = makeBroadcastTester( + op=pt.betaincinv, + expected=scipy.special.betaincinv, + good=_good_broadcast_ternary_betaincinv, +) + +TestBetaincinvInplaceBroadcast = makeBroadcastTester( + op=inplace.betaincinv_inplace, + expected=scipy.special.betaincinv, + good=_good_broadcast_ternary_betaincinv, + inplace=True, +) + _good_broadcast_quaternary_hyp2f1 = dict( normal=( random_ranged(0, 20, (2, 3)), From f96eb6a0aff5ebde53f870be6446f42a9b697c13 Mon Sep 17 00:00:00 2001 From: amyoshino Date: Mon, 18 Dec 2023 10:10:01 -0300 Subject: [PATCH 07/15] remove unused test case --- tests/tensor/test_math_scipy.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/tensor/test_math_scipy.py b/tests/tensor/test_math_scipy.py index 5dc5cc0a6a..43d34e2d4f 100644 --- a/tests/tensor/test_math_scipy.py +++ b/tests/tensor/test_math_scipy.py @@ -497,10 +497,6 @@ def test_gammaincc_ddk_performance(benchmark): empty=(np.asarray([], dtype=config.floatX), np.asarray([], dtype=config.floatX)), ) -_good_broadcast_binary_gamma_grad = dict( - normal=(random_ranged(0, 1000.0, (2, 3)), random_ranged(0.0, 1.0, (2, 3))) -) - TestGammaIncInvBroadcast = makeBroadcastTester( op=pt.gammaincinv, expected=expected_gammaincinv, From 667bb6f24bba8328a0cd8ec01089d6a2e70cda3d Mon Sep 17 00:00:00 2001 From: amyoshino Date: Mon, 18 Dec 2023 10:36:12 -0300 Subject: [PATCH 08/15] fixing jax test error --- tests/link/jax/test_scalar.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/tests/link/jax/test_scalar.py b/tests/link/jax/test_scalar.py index ff6ca697fb..efee340873 100644 --- a/tests/link/jax/test_scalar.py +++ b/tests/link/jax/test_scalar.py @@ -154,11 +154,8 @@ def test_erfinv(): @pytest.mark.parametrize( "op, test_values", [ - (betaincinv, (3.0, 5.5, 0.7)), (erfcx, (0.7,)), (erfcinv, (0.7,)), - (gammaincinv, (5.5, 0.7)), - (gammainccinv, (5.5, 0.7)), (iv, (0.3, 0.7)), ], ) @@ -171,6 +168,31 @@ def test_tfp_ops(op, test_values): compare_jax_and_py(fg, test_values) +def test_betaincinv(): + a = vector("a", dtype="float32") + b = vector("b", dtype="float32") + x = vector("x", dtype="float32") + out = betaincinv(a, b, x) + fg = FunctionGraph([a, b, x], [out]) + compare_jax_and_py(fg, [np.array([3.0, 5.5, 0.7])]) + + +def test_gammaincinv(): + k = vector("k", dtype="float32") + x = vector("x", dtype="float32") + out = gammaincinv(k, x) + fg = FunctionGraph([k, x], [out]) + compare_jax_and_py(fg, [np.array([5.5, 0.7])]) + + +def test_gammainccinv(): + k = vector("k", dtype="float32") + x = vector("x", dtype="float32") + out = gammainccinv(k, x) + fg = FunctionGraph([k, x], [out]) + compare_jax_and_py(fg, [np.array([5.5, 0.7])]) + + def test_psi(): x = scalar("x") out = psi(x) From 5241d36d6fa14ebfb13cc70e0a336d4ee0cd8c52 Mon Sep 17 00:00:00 2001 From: amyoshino Date: Mon, 18 Dec 2023 11:38:11 -0300 Subject: [PATCH 09/15] fixing jax test error - using float64 --- tests/link/jax/test_scalar.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/link/jax/test_scalar.py b/tests/link/jax/test_scalar.py index efee340873..d26c14ba11 100644 --- a/tests/link/jax/test_scalar.py +++ b/tests/link/jax/test_scalar.py @@ -169,25 +169,25 @@ def test_tfp_ops(op, test_values): def test_betaincinv(): - a = vector("a", dtype="float32") - b = vector("b", dtype="float32") - x = vector("x", dtype="float32") + a = vector("a", dtype="float64") + b = vector("b", dtype="float64") + x = vector("x", dtype="float64") out = betaincinv(a, b, x) fg = FunctionGraph([a, b, x], [out]) compare_jax_and_py(fg, [np.array([3.0, 5.5, 0.7])]) def test_gammaincinv(): - k = vector("k", dtype="float32") - x = vector("x", dtype="float32") + k = vector("k", dtype="float64") + x = vector("x", dtype="float64") out = gammaincinv(k, x) fg = FunctionGraph([k, x], [out]) compare_jax_and_py(fg, [np.array([5.5, 0.7])]) def test_gammainccinv(): - k = vector("k", dtype="float32") - x = vector("x", dtype="float32") + k = vector("k", dtype="float64") + x = vector("x", dtype="float64") out = gammainccinv(k, x) fg = FunctionGraph([k, x], [out]) compare_jax_and_py(fg, [np.array([5.5, 0.7])]) From 14b0837c73311fb25017c5026c8c949e7387966a Mon Sep 17 00:00:00 2001 From: amyoshino Date: Mon, 18 Dec 2023 20:56:36 -0300 Subject: [PATCH 10/15] fixing jax test args error --- tests/link/jax/test_scalar.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/link/jax/test_scalar.py b/tests/link/jax/test_scalar.py index d26c14ba11..474a496d59 100644 --- a/tests/link/jax/test_scalar.py +++ b/tests/link/jax/test_scalar.py @@ -175,6 +175,14 @@ def test_betaincinv(): out = betaincinv(a, b, x) fg = FunctionGraph([a, b, x], [out]) compare_jax_and_py(fg, [np.array([3.0, 5.5, 0.7])]) + compare_jax_and_py( + fg, + [ + np.array([5.5, 7.0]), + np.array([5.5, 7.0]), + np.array([0.25, 0.7]), + ], + ) def test_gammaincinv(): @@ -182,7 +190,7 @@ def test_gammaincinv(): x = vector("x", dtype="float64") out = gammaincinv(k, x) fg = FunctionGraph([k, x], [out]) - compare_jax_and_py(fg, [np.array([5.5, 0.7])]) + compare_jax_and_py(fg, [np.array([5.5, 7.0]), np.array([0.25, 0.7])]) def test_gammainccinv(): @@ -190,7 +198,7 @@ def test_gammainccinv(): x = vector("x", dtype="float64") out = gammainccinv(k, x) fg = FunctionGraph([k, x], [out]) - compare_jax_and_py(fg, [np.array([5.5, 0.7])]) + compare_jax_and_py(fg, [np.array([5.5, 7.0]), np.array([0.25, 0.7])]) def test_psi(): From 8fd35e7372364a70a0304699d3ec84a66b57d6f9 Mon Sep 17 00:00:00 2001 From: amyoshino Date: Mon, 18 Dec 2023 21:10:47 -0300 Subject: [PATCH 11/15] fixing jax test args error v2 --- tests/link/jax/test_scalar.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/link/jax/test_scalar.py b/tests/link/jax/test_scalar.py index 474a496d59..0469301791 100644 --- a/tests/link/jax/test_scalar.py +++ b/tests/link/jax/test_scalar.py @@ -174,7 +174,6 @@ def test_betaincinv(): x = vector("x", dtype="float64") out = betaincinv(a, b, x) fg = FunctionGraph([a, b, x], [out]) - compare_jax_and_py(fg, [np.array([3.0, 5.5, 0.7])]) compare_jax_and_py( fg, [ From bc715a65f5addb52b65f2af8b480e6ba8daf13c5 Mon Sep 17 00:00:00 2001 From: amyoshino Date: Sat, 30 Dec 2023 13:11:11 -0300 Subject: [PATCH 12/15] implementing beta function Op and adding it as beta and betaln in tensor.special --- pytensor/scalar/math.py | 28 ++------------------- pytensor/tensor/inplace.py | 5 ---- pytensor/tensor/math.py | 6 ----- pytensor/tensor/special.py | 25 +++++++++++++------ tests/tensor/test_math_scipy.py | 43 --------------------------------- tests/tensor/test_special.py | 29 ++++++++++++++++++++++ 6 files changed, 49 insertions(+), 87 deletions(-) diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index b71b8b11f1..58077923c7 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -43,6 +43,7 @@ upgrade_to_float_no_complex, ) from pytensor.scalar.loop import ScalarLoop +from pytensor.tensor.special import betaln class Erf(UnaryScalarOp): @@ -1706,31 +1707,6 @@ def inner_loop( return grad -class Beta(BinaryScalarOp): - """ - Beta function. - """ - - nfunc_spec = ("scipy.special.beta", 2, 1) - - def impl(self, a, b): - return scipy.special.beta(a, b) - - def grad(self, inputs, grads): - (a, b) = inputs - (gz,) = grads - return [ - gz * beta(a, b) * (polygamma(0, a) - polygamma(0, a + b)), - gz * beta(a, b) * (polygamma(0, b) - polygamma(0, a + b)), - ] - - def c_code(self, *args, **kwargs): - raise NotImplementedError() - - -beta = Beta(upgrade_to_float_no_complex, name="beta") - - class BetaIncInv(ScalarOp): """ Inverse of the regularized incomplete beta function. @@ -1748,7 +1724,7 @@ def grad(self, inputs, grads): grad_not_implemented(self, 0, a), grad_not_implemented(self, 0, b), gz - * beta(a, b) + * exp(betaln(a, b)) * ((1 - betaincinv(a, b, x)) ** (1 - b)) * (betaincinv(a, b, x) ** (1 - a)), ] diff --git a/pytensor/tensor/inplace.py b/pytensor/tensor/inplace.py index a403de48ab..73b3942327 100644 --- a/pytensor/tensor/inplace.py +++ b/pytensor/tensor/inplace.py @@ -353,11 +353,6 @@ def betaincinv_inplace(a, b, x): """Inverse of the regularized incomplete beta function""" -@scalar_elemwise -def beta_inplace(a, b): - """Beta function""" - - @scalar_elemwise def second_inplace(a): """Fill `a` with `b`""" diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 0b0f7826eb..45c926f501 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -1466,11 +1466,6 @@ def betaincinv(a, b, x): """Inverse of the regularized incomplete beta function""" -@scalar_elemwise -def beta(a, b): - """Beta function""" - - @scalar_elemwise def real(z): """Return real component of complex-valued tensor `z`.""" @@ -3080,7 +3075,6 @@ def vectorize_node_to_matmul(op, node, batched_x, batched_y): "log1mexp", "betainc", "betaincinv", - "beta", "real", "imag", "angle", diff --git a/pytensor/tensor/special.py b/pytensor/tensor/special.py index c57e061602..21cc50f078 100644 --- a/pytensor/tensor/special.py +++ b/pytensor/tensor/special.py @@ -6,7 +6,7 @@ from pytensor.graph.basic import Apply from pytensor.link.c.op import COp from pytensor.tensor.basic import as_tensor_variable -from pytensor.tensor.math import gamma, neg, sum +from pytensor.tensor.math import gamma, gammaln, neg, sum class SoftmaxGrad(COp): @@ -752,9 +752,20 @@ def factorial(n): return gamma(n + 1) -__all__ = [ - "softmax", - "log_softmax", - "poch", - "factorial", -] +def beta(a, b): + """ + Beta function. + + """ + return (gamma(a) * gamma(b)) / gamma(a + b) + + +def betaln(a, b): + """ + Log beta function. + + """ + return gammaln(a) + gammaln(b) - gammaln(a + b) + + +__all__ = ["softmax", "log_softmax", "poch", "factorial", "beta", "betaln"] diff --git a/tests/tensor/test_math_scipy.py b/tests/tensor/test_math_scipy.py index 43d34e2d4f..3178d53b4e 100644 --- a/tests/tensor/test_math_scipy.py +++ b/tests/tensor/test_math_scipy.py @@ -82,7 +82,6 @@ def scipy_special_gammal(k, x): expected_sigmoid = scipy.special.expit expected_hyp2f1 = scipy.special.hyp2f1 expected_betaincinv = scipy.special.betaincinv -expected_beta = scipy.special.beta TestErfBroadcast = makeBroadcastTester( op=pt.erf, @@ -927,48 +926,6 @@ def test_beta_inc_stan_grad_combined(self): ) -rng = np.random.default_rng(seed=utt.fetch_seed()) -_good_broadcast_binary_beta = dict( - normal=( - random_ranged(1e-2, 100, (2, 3), rng=rng), - random_ranged(1e-2, 100, (2, 3), rng=rng), - ), - integers=( - integers_ranged(1, 100, (2, 3), rng=rng), - integers_ranged(1, 100, (2, 3), rng=rng), - ), - uint8=( - integers_ranged(1, 100, (2, 3), rng=rng).astype("uint8"), - integers_ranged(1, 100, (2, 3), rng=rng).astype("uint8"), - ), - uint16=( - integers_ranged(1, 100, (2, 3), rng=rng).astype("uint16"), - integers_ranged(1, 100, (2, 3), rng=rng).astype("uint16"), - ), -) - -_grad_broadcast_binary_beta = dict( - normal=( - random_ranged(1e-2, 100, (2, 3), rng=rng), - random_ranged(1e-2, 100, (2, 3), rng=rng), - ) -) - -TestBetaBroadcast = makeBroadcastTester( - op=pt.beta, - expected=scipy.special.beta, - good=_good_broadcast_binary_beta, - grad=_grad_broadcast_binary_beta, -) - -TestBetaInplaceBroadcast = makeBroadcastTester( - op=inplace.beta_inplace, - expected=scipy.special.beta, - good=_good_broadcast_binary_beta, - grad=_grad_broadcast_binary_beta, - inplace=True, -) - _good_broadcast_ternary_betaincinv = dict( normal=( random_ranged(0, 1000, (2, 3)), diff --git a/tests/tensor/test_special.py b/tests/tensor/test_special.py index 17a9c05eff..24b624a8fb 100644 --- a/tests/tensor/test_special.py +++ b/tests/tensor/test_special.py @@ -1,5 +1,6 @@ import numpy as np import pytest +from scipy.special import beta as scipy_beta from scipy.special import factorial as scipy_factorial from scipy.special import log_softmax as scipy_log_softmax from scipy.special import poch as scipy_poch @@ -11,6 +12,8 @@ LogSoftmax, Softmax, SoftmaxGrad, + beta, + betaln, factorial, log_softmax, poch, @@ -171,3 +174,29 @@ def test_factorial(n): np.testing.assert_allclose( actual, expected, rtol=1e-7 if config.floatX == "float64" else 1e-5 ) + + +def test_beta(): + _a, _b = vectors("a", "b") + actual_fn = function([_a, _b], beta(_a, _b)) + + a = random_ranged(0, 5, (2,)) + b = random_ranged(0, 5, (2,)) + actual = actual_fn(a, b) + expected = scipy_beta(a, b) + np.testing.assert_allclose( + actual, expected, rtol=1e-7 if config.floatX == "float64" else 1e-5 + ) + + +def test_betaln(): + _a, _b = vectors("a", "b") + actual_fn = function([_a, _b], betaln(_a, _b)) + + a = random_ranged(0, 5, (2,)) + b = random_ranged(0, 5, (2,)) + actual = actual_fn(a, b) + expected = np.exp(scipy_beta(a, b)) + np.testing.assert_allclose( + actual, expected, rtol=1e-7 if config.floatX == "float64" else 1e-5 + ) From 9ad16af34c57de49c4d9cadca47eeca10f68028f Mon Sep 17 00:00:00 2001 From: amyoshino Date: Sat, 30 Dec 2023 14:32:37 -0300 Subject: [PATCH 13/15] fix betaln test and circular import error --- pytensor/scalar/math.py | 9 ++++++++- tests/tensor/test_special.py | 4 ++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index 58077923c7..7eba128100 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -43,7 +43,6 @@ upgrade_to_float_no_complex, ) from pytensor.scalar.loop import ScalarLoop -from pytensor.tensor.special import betaln class Erf(UnaryScalarOp): @@ -1736,6 +1735,14 @@ def c_code(self, *args, **kwargs): betaincinv = BetaIncInv(upgrade_to_float_no_complex, name="betaincinv") +def betaln(a, b): + """ + Beta function from gamma function. + """ + + return gammaln(a) + gammaln(b) - gammaln(a + b) + + class Hyp2F1(ScalarOp): """ Gaussian hypergeometric function ``2F1(a, b; c; z)``. diff --git a/tests/tensor/test_special.py b/tests/tensor/test_special.py index 24b624a8fb..a7448f1d86 100644 --- a/tests/tensor/test_special.py +++ b/tests/tensor/test_special.py @@ -195,8 +195,8 @@ def test_betaln(): a = random_ranged(0, 5, (2,)) b = random_ranged(0, 5, (2,)) - actual = actual_fn(a, b) - expected = np.exp(scipy_beta(a, b)) + actual = np.exp(actual_fn(a, b)) + expected = scipy_beta(a, b) np.testing.assert_allclose( actual, expected, rtol=1e-7 if config.floatX == "float64" else 1e-5 ) From 908b5a104da70e8dc39259c0b4fb045fd747acd0 Mon Sep 17 00:00:00 2001 From: "Adriano M. Yoshino" Date: Tue, 2 Jan 2024 08:57:21 -0300 Subject: [PATCH 14/15] Update pytensor/tensor/special.py Default pre-commit to multi-line Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pytensor/tensor/special.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/tensor/special.py b/pytensor/tensor/special.py index 21cc50f078..e5884d0afe 100644 --- a/pytensor/tensor/special.py +++ b/pytensor/tensor/special.py @@ -768,4 +768,4 @@ def betaln(a, b): return gammaln(a) + gammaln(b) - gammaln(a + b) -__all__ = ["softmax", "log_softmax", "poch", "factorial", "beta", "betaln"] +__all__ = ["softmax", "log_softmax", "poch", "factorial", "beta", "betaln",] From aeab77398f85db384bc3000f220fa36e59baa654 Mon Sep 17 00:00:00 2001 From: amyoshino Date: Tue, 2 Jan 2024 09:01:12 -0300 Subject: [PATCH 15/15] running pre-commit --- pytensor/tensor/special.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/special.py b/pytensor/tensor/special.py index e5884d0afe..7b5e52d637 100644 --- a/pytensor/tensor/special.py +++ b/pytensor/tensor/special.py @@ -768,4 +768,11 @@ def betaln(a, b): return gammaln(a) + gammaln(b) - gammaln(a + b) -__all__ = ["softmax", "log_softmax", "poch", "factorial", "beta", "betaln",] +__all__ = [ + "softmax", + "log_softmax", + "poch", + "factorial", + "beta", + "betaln", +]