From af7127c507fed80f554c7614254f6baf93696ff4 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Mon, 17 Jun 2024 16:23:17 +0200 Subject: [PATCH 1/4] deprecate samples arg in prior_predictive. closes #7173 --- pymc/sampling/forward.py | 18 ++++++++++++++--- tests/sampling/test_forward.py | 35 +++++++++++++++++++++------------- 2 files changed, 37 insertions(+), 16 deletions(-) diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index ce51f5cc72..ca674f3155 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -338,19 +338,20 @@ def observed_dependent_deterministics(model: Model): def sample_prior_predictive( - samples: int = 500, + draws: int = 500, model: Model | None = None, var_names: Iterable[str] | None = None, random_seed: RandomState = None, return_inferencedata: bool = True, idata_kwargs: dict | None = None, compile_kwargs: dict | None = None, + samples: int | None = None, ) -> InferenceData | dict[str, np.ndarray]: """Generate samples from the prior predictive distribution. Parameters ---------- - samples : int + draws : int Number of samples from the prior predictive to generate. Defaults to 500. model : Model (optional if in ``with`` context) var_names : Iterable[str] @@ -366,6 +367,8 @@ def sample_prior_predictive( Keyword arguments for :func:`pymc.to_inference_data` compile_kwargs: dict, optional Keyword arguments for :func:`pymc.pytensorf.compile_pymc`. + samples : int + Number of samples from the prior predictive to generate. Deprecated in favor of `draws`. Returns ------- @@ -373,6 +376,15 @@ def sample_prior_predictive( An ArviZ ``InferenceData`` object containing the prior and prior predictive samples (default), or a dictionary with variable names as keys and samples as numpy arrays. """ + if samples is not None: + warnings.warn( + f"The samples argument has been deprecated in favor of draws. Use draws={samples} going forward.", + DeprecationWarning, + stacklevel=2, + ) + + draws = samples + model = modelcontext(model) if model.potentials: @@ -415,7 +427,7 @@ def sample_prior_predictive( # All model variables have a name, but mypy does not know this _log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore - values = zip(*(sampler_fn() for i in range(samples))) + values = zip(*(sampler_fn() for i in range(draws))) data = {k: np.stack(v) for k, v in zip(names, values)} if data is None: diff --git a/tests/sampling/test_forward.py b/tests/sampling/test_forward.py index 3466089472..89e9196cbb 100644 --- a/tests/sampling/test_forward.py +++ b/tests/sampling/test_forward.py @@ -794,12 +794,12 @@ def test_logging_sampled_basic_rvs_prior(self, caplog): z = pm.Normal("z", y, observed=0) with m: - pm.sample_prior_predictive(samples=1) + pm.sample_prior_predictive(draws=1) assert caplog.record_tuples == [("pymc.sampling.forward", logging.INFO, "Sampling: [x, z]")] caplog.clear() with m: - pm.sample_prior_predictive(samples=1, var_names=["x"]) + pm.sample_prior_predictive(draws=1, var_names=["x"]) assert caplog.record_tuples == [("pymc.sampling.forward", logging.INFO, "Sampling: [x]")] caplog.clear() @@ -1028,7 +1028,7 @@ def test_observed_data_needed_in_pp(self): mu = x_data.sum(-1) pm.Normal("y", mu=mu, sigma=sigma, observed=y_data, shape=mu.shape, dims=("trial",)) - prior = pm.sample_prior_predictive(samples=25).prior + prior = pm.sample_prior_predictive(draws=25).prior fake_idata = InferenceData(posterior=prior) @@ -1052,7 +1052,7 @@ def test_observed_data_needed_in_pp(self): mu = (y_data.sum() * x_data).sum(-1) pm.Normal("y", mu=mu, sigma=sigma, observed=y_data, shape=mu.shape, dims=("trial",)) - prior = pm.sample_prior_predictive(samples=25).prior + prior = pm.sample_prior_predictive(draws=25).prior fake_idata = InferenceData(posterior=prior) @@ -1135,7 +1135,7 @@ def test_multivariate2(self, seeded_test): compute_convergence_checks=False, ) sim_priors = pm.sample_prior_predictive( - return_inferencedata=False, samples=20, model=dm_model + return_inferencedata=False, draws=20, model=dm_model ) sim_ppc = pm.sample_posterior_predictive( burned_trace, return_inferencedata=False, model=dm_model @@ -1227,7 +1227,7 @@ def test_zeroinflatedpoisson(self): mu = pm.Beta("mu", alpha=1, beta=1) psi = pm.HalfNormal("psi", sigma=1) pm.ZeroInflatedPoisson("suppliers", psi=psi, mu=mu, size=20) - gen_data = pm.sample_prior_predictive(samples=5000) + gen_data = pm.sample_prior_predictive(draws=5000) assert gen_data.prior["mu"].shape == (1, 5000) assert gen_data.prior["psi"].shape == (1, 5000) assert gen_data.prior["suppliers"].shape == (1, 5000, 20) @@ -1240,7 +1240,7 @@ def test_potentials_warning(self): with m: with pytest.warns(UserWarning, match=warning_msg): - pm.sample_prior_predictive(samples=5) + pm.sample_prior_predictive(draws=5) def test_transformed_vars_not_supported(self): with pm.Model() as model: @@ -1260,7 +1260,7 @@ def test_issue_4490(self): c = pm.Normal("c") d = pm.Normal("d") prior1 = pm.sample_prior_predictive( - samples=1, var_names=["a", "b", "c", "d"], random_seed=seed + draws=1, var_names=["a", "b", "c", "d"], random_seed=seed ) with pm.Model() as m2: @@ -1269,7 +1269,7 @@ def test_issue_4490(self): c = pm.Normal("c") d = pm.Normal("d") prior2 = pm.sample_prior_predictive( - samples=1, var_names=["b", "a", "d", "c"], random_seed=seed + draws=1, var_names=["b", "a", "d", "c"], random_seed=seed ) assert prior1.prior["a"] == prior2.prior["a"] @@ -1284,7 +1284,7 @@ def test_pytensor_function_kwargs(self): y = pm.Deterministic("y", x + sharedvar) prior = pm.sample_prior_predictive( - samples=5, + draws=5, return_inferencedata=False, compile_kwargs=dict( mode=Mode("py"), @@ -1308,7 +1308,7 @@ def test_sample_from_xarray_prior(self, point_list_arg_bug_fixture): with pmodel: prior = pm.sample_prior_predictive( - samples=20, + draws=20, return_inferencedata=False, ) idat = pm.to_inference_data(trace, prior=prior) @@ -1367,7 +1367,7 @@ def test_distinct_rvs(): Y_rv = pm.Normal("y") pp_samples = pm.sample_prior_predictive( - samples=2, return_inferencedata=False, random_seed=npr.RandomState(2023532) + draws=2, return_inferencedata=False, random_seed=npr.RandomState(2023532) ) assert X_rv.owner.inputs[0] != Y_rv.owner.inputs[0] @@ -1377,7 +1377,7 @@ def test_distinct_rvs(): Y_rv = pm.Normal("y") pp_samples_2 = pm.sample_prior_predictive( - samples=2, return_inferencedata=False, random_seed=npr.RandomState(2023532) + draws=2, return_inferencedata=False, random_seed=npr.RandomState(2023532) ) assert np.array_equal(pp_samples["y"], pp_samples_2["y"]) @@ -1706,3 +1706,12 @@ def test_observed_dependent_deterministics(): det_mixed = pm.Deterministic("det_mixed", free + obs) assert set(observed_dependent_deterministics(m)) == {det_obs, det_obs2, det_mixed} + + +def test_sample_prior_predictive_samples_deprecated_warns() -> None: + with pm.Model() as m: + pm.Normal("a") + + match = "The samples argument has been deprecated" + with pytest.warns(DeprecationWarning, match=match): + pm.sample_prior_predictive(model=m, samples=10) From 185e6f3206429dea095aa422ba049c786c4c04e9 Mon Sep 17 00:00:00 2001 From: Will Dean <57733339+wd60622@users.noreply.github.com> Date: Mon, 17 Jun 2024 16:42:04 +0200 Subject: [PATCH 2/4] reduce stacklevel Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pymc/sampling/forward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index ca674f3155..13696c8c49 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -380,7 +380,7 @@ def sample_prior_predictive( warnings.warn( f"The samples argument has been deprecated in favor of draws. Use draws={samples} going forward.", DeprecationWarning, - stacklevel=2, + stacklevel=1, ) draws = samples From d0f61a2fffdc1f9e7f630dbad7ec43701238f49e Mon Sep 17 00:00:00 2001 From: Will Dean Date: Mon, 17 Jun 2024 16:49:38 +0200 Subject: [PATCH 3/4] change samples -> draws --- tests/distributions/test_mixture.py | 10 +++++----- tests/distributions/test_multivariate.py | 6 +++--- tests/gp/test_hsgp_approx.py | 8 ++++---- tests/model/test_core.py | 2 +- tests/sampling/test_deterministic.py | 2 +- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/distributions/test_mixture.py b/tests/distributions/test_mixture.py index e328c96931..c44796fc1c 100644 --- a/tests/distributions/test_mixture.py +++ b/tests/distributions/test_mixture.py @@ -561,7 +561,7 @@ def test_single_poisson_predictive_sampling_shape(self): n_samples = 30 with model: - prior = sample_prior_predictive(samples=n_samples, return_inferencedata=False) + prior = sample_prior_predictive(draws=n_samples, return_inferencedata=False) ppc = sample_posterior_predictive( n_samples * [self.get_initial_point(model)], return_inferencedata=False ) @@ -607,7 +607,7 @@ def test_list_mvnormals_predictive_sampling_shape(self): n_samples = 20 with model: - prior = sample_prior_predictive(samples=n_samples, return_inferencedata=False) + prior = sample_prior_predictive(draws=n_samples, return_inferencedata=False) ppc = sample_posterior_predictive( n_samples * [self.get_initial_point(model)], return_inferencedata=False ) @@ -1028,7 +1028,7 @@ def test_with_multinomial(self, seeded_test, batch_shape): comp_dists=comp_dists, shape=(*batch_shape, 3), ) - prior = sample_prior_predictive(samples=self.n_samples, return_inferencedata=False) + prior = sample_prior_predictive(draws=self.n_samples, return_inferencedata=False) assert prior["mixture"].shape == (self.n_samples, *batch_shape, 3) assert draw(mixture, draws=self.size).shape == (self.size, *batch_shape, 3) @@ -1060,7 +1060,7 @@ def test_with_mvnormal(self, seeded_test): with Model() as model: comp_dists = MvNormal.dist(mu=mu, chol=chol, shape=(self.mixture_comps, 3)) mixture = Mixture("mixture", w=w, comp_dists=comp_dists, shape=(3,)) - prior = sample_prior_predictive(samples=self.n_samples, return_inferencedata=False) + prior = sample_prior_predictive(draws=self.n_samples, return_inferencedata=False) assert prior["mixture"].shape == (self.n_samples, 3) assert draw(mixture, draws=self.size).shape == (self.size, 3) @@ -1084,7 +1084,7 @@ def test_broadcasting_in_shape(self): mu = Gamma("mu", 1.0, 1.0, shape=2) comp_dists = Poisson.dist(mu, shape=2) mix = Mixture("mix", w=np.ones(2) / 2, comp_dists=comp_dists, shape=(1000,)) - prior = sample_prior_predictive(samples=self.n_samples, return_inferencedata=False) + prior = sample_prior_predictive(draws=self.n_samples, return_inferencedata=False) assert prior["mix"].shape == (self.n_samples, 1000) diff --git a/tests/distributions/test_multivariate.py b/tests/distributions/test_multivariate.py index 6ae0169cd4..2d5b0e7585 100644 --- a/tests/distributions/test_multivariate.py +++ b/tests/distributions/test_multivariate.py @@ -1448,7 +1448,7 @@ def test_with_chol_rv(self): "chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True ) mv = pm.MvNormal("mv", mu, chol=chol, size=4) - prior = pm.sample_prior_predictive(samples=10, return_inferencedata=False) + prior = pm.sample_prior_predictive(draws=10, return_inferencedata=False) assert prior["mv"].shape == (10, 4, 3) @@ -1462,7 +1462,7 @@ def test_with_cov_rv( "chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True ) mv = pm.MvNormal("mv", mu, cov=pm.math.dot(chol, chol.T), size=4) - prior = pm.sample_prior_predictive(samples=10, return_inferencedata=False) + prior = pm.sample_prior_predictive(draws=10, return_inferencedata=False) assert prior["mv"].shape == (10, 4, 3) @@ -1473,7 +1473,7 @@ def test_with_lkjcorr_matrix( corr = pm.LKJCorr("corr", n=3, eta=2, return_matrix=True) pm.Deterministic("corr_mat", corr) mv = pm.MvNormal("mv", 0.0, cov=corr, size=4) - prior = pm.sample_prior_predictive(samples=10, return_inferencedata=False) + prior = pm.sample_prior_predictive(draws=10, return_inferencedata=False) assert prior["corr_mat"].shape == (10, 3, 3) # square assert (prior["corr_mat"][:, [0, 1, 2], [0, 1, 2]] == 1.0).all() # 1.0 on diagonal diff --git a/tests/gp/test_hsgp_approx.py b/tests/gp/test_hsgp_approx.py index d80db1bb76..5033aa8260 100644 --- a/tests/gp/test_hsgp_approx.py +++ b/tests/gp/test_hsgp_approx.py @@ -213,7 +213,7 @@ def test_prior(self, model, cov_func, X1, parametrization, rng): gp = pm.gp.Latent(cov_func=cov_func) f2 = gp.prior("f2", X=X1) - idata = pm.sample_prior_predictive(samples=1000, random_seed=rng) + idata = pm.sample_prior_predictive(draws=1000, random_seed=rng) samples1 = az.extract(idata.prior["f1"])["f1"].values.T samples2 = az.extract(idata.prior["f2"])["f2"].values.T @@ -240,7 +240,7 @@ def test_conditional(self, model, cov_func, X1, parametrization): f = hsgp.prior("f", X=X1) fc = hsgp.conditional("fc", Xnew=X1) - idata = pm.sample_prior_predictive(samples=1000) + idata = pm.sample_prior_predictive(draws=1000) samples1 = az.extract(idata.prior["f"])["f"].values.T samples2 = az.extract(idata.prior["fc"])["fc"].values.T @@ -300,7 +300,7 @@ def test_prior(self, model, cov_func, eta, X1, rng): gp = pm.gp.Latent(cov_func=eta**2 * cov_func) f2 = gp.prior("f2", X=X1) - idata = pm.sample_prior_predictive(samples=1000, random_seed=rng) + idata = pm.sample_prior_predictive(draws=1000, random_seed=rng) samples1 = az.extract(idata.prior["f1"])["f1"].values.T samples2 = az.extract(idata.prior["f2"])["f2"].values.T @@ -321,7 +321,7 @@ def test_conditional_periodic(self, model, cov_func, X1): f = hsgp.prior("f", X=X1) fc = hsgp.conditional("fc", Xnew=X1) - idata = pm.sample_prior_predictive(samples=1000) + idata = pm.sample_prior_predictive(draws=1000) samples1 = az.extract(idata.prior["f"])["f"].values.T samples2 = az.extract(idata.prior["fc"])["fc"].values.T diff --git a/tests/model/test_core.py b/tests/model/test_core.py index c00250b739..a4f4255e73 100644 --- a/tests/model/test_core.py +++ b/tests/model/test_core.py @@ -873,7 +873,7 @@ def test_none_coords_autonumbering(self): m.add_coord(name="a", values=None, length=3) m.add_coord(name="b", values=range(5)) x = pm.Normal("x", dims=("a", "b")) - prior = pm.sample_prior_predictive(samples=2).prior + prior = pm.sample_prior_predictive(draws=2).prior assert prior["x"].shape == (1, 2, 3, 5) assert list(prior.coords["a"].values) == list(range(3)) assert list(prior.coords["b"].values) == list(range(5)) diff --git a/tests/sampling/test_deterministic.py b/tests/sampling/test_deterministic.py index d1d2b8474c..f42e1d7eba 100644 --- a/tests/sampling/test_deterministic.py +++ b/tests/sampling/test_deterministic.py @@ -34,7 +34,7 @@ def test_compute_deterministics(): sigma = Deterministic("sigma", sigma_raw.exp()) dataset = sample_prior_predictive( - samples=5, model=m, var_names=["mu_raw", "sigma_raw"], random_seed=22 + draws=5, model=m, var_names=["mu_raw", "sigma_raw"], random_seed=22 ).prior # Test default From 627ad5f63714a5a82702e85b274a6da300995028 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Mon, 17 Jun 2024 16:51:56 +0200 Subject: [PATCH 4/4] change in notebook --- docs/source/learn/core_notebooks/posterior_predictive.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/learn/core_notebooks/posterior_predictive.ipynb b/docs/source/learn/core_notebooks/posterior_predictive.ipynb index faaa92b732..a648cb19b2 100644 --- a/docs/source/learn/core_notebooks/posterior_predictive.ipynb +++ b/docs/source/learn/core_notebooks/posterior_predictive.ipynb @@ -156,7 +156,7 @@ " sigma = pm.Exponential(\"sigma\", 1.0)\n", "\n", " pm.Normal(\"obs\", mu=mu, sigma=sigma, observed=outcome_scaled)\n", - " idata = pm.sample_prior_predictive(samples=50, random_seed=rng)" + " idata = pm.sample_prior_predictive(draws=50, random_seed=rng)" ] }, { @@ -225,7 +225,7 @@ " sigma = pm.Exponential(\"sigma\", 1.0)\n", "\n", " pm.Normal(\"obs\", mu=mu, sigma=sigma, observed=outcome_scaled)\n", - " idata = pm.sample_prior_predictive(samples=50, random_seed=rng)" + " idata = pm.sample_prior_predictive(draws=50, random_seed=rng)" ] }, {