diff --git a/docs/source/api/math.rst b/docs/source/api/math.rst index 67b487194d..260e58f214 100644 --- a/docs/source/api/math.rst +++ b/docs/source/api/math.rst @@ -19,8 +19,10 @@ Functions exposed in pymc namespace invlogit probit invprobit + logaddexp logsumexp + Functions exposed in pymc.math ------------------------------ @@ -28,47 +30,87 @@ Functions exposed in pymc.math .. autosummary:: :toctree: generated/ - dot - constant - flatten - zeros_like - ones_like - stack - concatenate - sum + abs prod - lt - gt - le - ge + dot eq neq - switch - clip - where - and_ - or_ - abs + ge + gt + le + lt exp log - cos + sgn + sqr + sqrt + sum + ceil + floor sin - tan - cosh sinh + arcsin + arcsinh + cos + cosh + arccos + arccosh + tan tanh - sqr - sqrt - erf - erfinv - dot + arctan + arctanh + cumprod + cumsum + matmul + and_ + broadcast_to + clip + concatenate + flatten + or_ + stack + switch + where + flatten_list + constant + max maximum + mean + min minimum - sgn - ceil - floor - matrix_inverse - sigmoid + round + erf + erfc + erfcinv + erfinv + log1pexp + log1mexp + logaddexp logsumexp - invlogit + logdiffexp logit + invlogit + probit + invprobit + sigmoid + softmax + log_softmax + logbern + full + full_like + ones + ones_like + zeros + zeros_like + kronecker + cartesian + kron_dot + kron_solve_lower + kron_solve_upper + kron_diag + flat_outer + expand_packed_triangular + batched_diag + block_diagonal + matrix_inverse + logdet diff --git a/pymc/math.py b/pymc/math.py index 7fe8d1e5e5..b85ffe63ce 100644 --- a/pymc/math.py +++ b/pymc/math.py @@ -73,6 +73,7 @@ ones_like, or_, prod, + round, sgn, sigmoid, sin, @@ -178,6 +179,7 @@ "expand_packed_triangular", "batched_diag", "block_diagonal", + "round", ] @@ -272,20 +274,6 @@ def kron_diag(*diags): return reduce(flat_outer, diags) -def round(*args, **kwargs): - """ - Temporary function to silence round warning in PyTensor. Please remove - when the warning disappears. - """ - kwargs["mode"] = "half_to_even" - return pt.round(*args, **kwargs) - - -def tround(*args, **kwargs): - warnings.warn("tround is deprecated. Use round instead.") - return round(*args, **kwargs) - - def logdiffexp(a, b): """log(exp(a) - exp(b))""" return a + pt.log1mexp(b - a) @@ -293,6 +281,11 @@ def logdiffexp(a, b): def logdiffexp_numpy(a, b): """log(exp(a) - exp(b))""" + warnings.warn( + "pymc.math.logdiffexp_numpy is being deprecated.", + FutureWarning, + stacklevel=2, + ) return a + log1mexp_numpy(b - a, negative_input=True) @@ -341,6 +334,11 @@ def log1mexp_numpy(x, *, negative_input=False): For details, see https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf """ + warnings.warn( + "pymc.math.log1mexp_numpy is being deprecated.", + FutureWarning, + stacklevel=2, + ) x = np.asarray(x, dtype="float") if not negative_input: diff --git a/tests/distributions/test_continuous.py b/tests/distributions/test_continuous.py index 901618dc28..91d6cdbaf6 100644 --- a/tests/distributions/test_continuous.py +++ b/tests/distributions/test_continuous.py @@ -417,8 +417,11 @@ def test_kumaraswamy(self): def scipy_log_pdf(value, a, b): return np.log(a) + np.log(b) + (a - 1) * np.log(value) + (b - 1) * np.log(1 - value**a) + def log1mexp(x): + return np.log1p(-np.exp(x)) if x < np.log(0.5) else np.log(-np.expm1(x)) + def scipy_log_cdf(value, a, b): - return pm.math.log1mexp_numpy(b * np.log1p(-(value**a)), negative_input=True) + return log1mexp(b * np.log1p(-(value**a))) check_logp( pm.Kumaraswamy, diff --git a/tests/test_math.py b/tests/test_math.py index 544bf4ce93..40c3b70db5 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -145,45 +145,46 @@ def test_log1mexp(): ) actual = pt.log1mexp(-vals).eval() npt.assert_allclose(actual, expected) + with warnings.catch_warnings(): warnings.filterwarnings("ignore", "divide by zero encountered in log", RuntimeWarning) warnings.filterwarnings("ignore", "invalid value encountered in log", RuntimeWarning) - actual_ = log1mexp_numpy(-vals, negative_input=True) + with pytest.warns(FutureWarning, match="deprecated"): + actual_ = log1mexp_numpy(-vals, negative_input=True) npt.assert_allclose(actual_, expected) # Check that input was not changed in place npt.assert_allclose(vals, vals_) +@pytest.mark.filterwarnings("error") def test_log1mexp_numpy_no_warning(): """Assert RuntimeWarning is not raised for very small numbers""" - with warnings.catch_warnings(): - warnings.simplefilter("error") + with pytest.warns(FutureWarning, match="deprecated"): log1mexp_numpy(-1e-25, negative_input=True) def test_log1mexp_numpy_integer_input(): - assert np.isclose(log1mexp_numpy(-2, negative_input=True), pt.log1mexp(-2).eval()) + with pytest.warns(FutureWarning, match="deprecated"): + assert np.isclose(log1mexp_numpy(-2, negative_input=True), pt.log1mexp(-2).eval()) +@pytest.mark.filterwarnings("error") def test_log1mexp_deprecation_warnings(): - with pytest.warns( - FutureWarning, - match="pymc.math.log1mexp_numpy will expect a negative input", - ): - res_pos = log1mexp_numpy(2) + with pytest.warns(FutureWarning, match="deprecated"): + with pytest.warns( + FutureWarning, + match="pymc.math.log1mexp_numpy will expect a negative input", + ): + res_pos = log1mexp_numpy(2) - with warnings.catch_warnings(): - warnings.simplefilter("error") res_neg = log1mexp_numpy(-2, negative_input=True) - with pytest.warns( - FutureWarning, - match="pymc.math.log1mexp will expect a negative input", - ): - res_pos_at = log1mexp(2).eval() + with pytest.warns( + FutureWarning, + match="pymc.math.log1mexp will expect a negative input", + ): + res_pos_at = log1mexp(2).eval() - with warnings.catch_warnings(): - warnings.simplefilter("error") res_neg_at = log1mexp(-2, negative_input=True).eval() assert np.isclose(res_pos, res_neg) @@ -196,8 +197,8 @@ def test_logdiffexp(): with warnings.catch_warnings(): warnings.filterwarnings("ignore", "divide by zero encountered in log", RuntimeWarning) b = np.log([0, 1, 2, 3]) - - assert np.allclose(logdiffexp_numpy(a, b), 0) + with pytest.warns(FutureWarning, match="deprecated"): + assert np.allclose(logdiffexp_numpy(a, b), 0) assert np.allclose(logdiffexp(a, b).eval(), 0)