
sample_smc #319

@myravian


Hi there,

I was testing pymc_experimental/inference/smc/sampling.py and noticed the following issues:

  • sampling fails for models that contain `pm.Dirichlet`, raising a shape error at `tmp = logp_fn(*[p.squeeze() for p in particles])[0]`
  • `arviz_from_particles` fails for RVs declared with `shape=(1,)`
  • converting the resulting InferenceData to netCDF fails because the integrations value is neither an `int` nor an `np.array`
  • the InferenceData does not include the marginal likelihood. Do you think this will be implemented in the future, or is it not possible?
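A possible explanation for the first two points, offered as a hypothesis rather than a confirmed diagnosis: `.squeeze()` removes *every* length-1 axis, so particles for a variable declared with `shape=(1,)` end up looking like particles for a scalar. A two-component Dirichlet may hit the same problem, because PyMC samples it on the simplex-transformed scale, where its value has a single unconstrained coordinate. The sketch below assumes particles are laid out as `(n_particles, *var_shape)` (not checked against `sampling.py`) and uses a hypothetical `restore_shape` helper that reshapes with the variable's declared shape instead of squeezing.

```python
import numpy as np

n_particles = 5

# Assumed layout: one row per particle, followed by the variable's own shape.
# A variable declared with shape=(1,) therefore has particles of shape (5, 1).
particles_shape_1 = np.random.default_rng(0).normal(size=(n_particles, 1))

# .squeeze() drops every length-1 axis, so the (1,) event axis disappears
squeezed = particles_shape_1.squeeze()
print(squeezed.shape)  # (5,): indistinguishable from a scalar variable

# With a single particle it would even drop the particle axis
print(np.zeros((1, 1)).squeeze().shape)  # ()


def restore_shape(particles, var_shape):
    """Hypothetical helper: reshape to (n_particles, *var_shape) using the
    variable's declared shape rather than squeezing all length-1 axes."""
    particles = np.asarray(particles)
    return particles.reshape((particles.shape[0], *var_shape))


print(restore_shape(squeezed, (1,)).shape)  # (5, 1)
print(restore_shape(squeezed, ()).shape)    # (5,)
```

Reshaping with an explicit target shape keeps scalar and `shape=(1,)` variables distinct, which a blanket `squeeze` cannot do.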
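On the marginal likelihood: in tempered (likelihood-annealed) SMC it is usually estimated as a by-product of the reweighting steps. If each stage raises the inverse temperature from beta_{t-1} to beta_t and the particles are resampled to equal weights afterwards, the log evidence is the sum over stages of the log of the mean incremental weight exp((beta_t - beta_{t-1}) * loglike). The sketch below computes that sum with `logsumexp` for numerical stability. It assumes the beta schedule and the per-stage log-likelihoods of the particles were recorded; `log_marginal_likelihood` is a hypothetical name, not an existing function in pymc_experimental or BlackJAX.

```python
import numpy as np
from scipy.special import logsumexp


def log_marginal_likelihood(betas, stage_loglikes):
    """Hypothetical helper: estimate log p(y) from a tempered SMC run.

    betas: inverse temperatures [0, beta_1, ..., 1], length T + 1
    stage_loglikes: list of T arrays; entry t holds the log-likelihood of the
        N equally weighted particles entering the step betas[t] -> betas[t + 1]

    Assumes the particles are resampled to equal weights at every stage.
    """
    log_z = 0.0
    for t, loglike in enumerate(stage_loglikes):
        loglike = np.asarray(loglike, dtype=float)
        delta = betas[t + 1] - betas[t]
        # log of the mean incremental weight: log( (1/N) * sum_i exp(delta * l_i) )
        log_z += logsumexp(delta * loglike) - np.log(loglike.size)
    return log_z
```

Accumulating in log space avoids underflow when the likelihood values are very small, which is the common case for realistic datasets.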

Thanks a lot for the SMC blackjax implementation, it's very useful!

Cheers,
Vian

PS: here is some code that reproduces the errors:

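```python
import numpy as np
import pymc as pm
from pymc_experimental.inference.smc.sampling import sample_smc

# Synthetic linear data: y = a * x + b + Gaussian noise (sd = 2)
real_a = 0.2
real_b = 2
x = np.linspace(1, 100)
y = real_a * x + real_b + np.random.normal(0, 2, len(x))

with pm.Model() as model:
    a = pm.Normal("a", mu=10, sigma=10)
    b = pm.Normal("b", mu=10, sigma=10)
    # Likelihood (assumed addition: otherwise x and y are never used)
    pm.Normal("y_obs", mu=a * x + b, sigma=2, observed=y)

    # Uncommenting either of the following lines produces an error:
    # c = pm.Normal("c", mu=10, sigma=10, shape=(1,))  # arviz_from_particles fails
    # d = pm.Dirichlet("d", [1, 1])                    # shape error in logp_fn

    trace = sample_smc(
        n_particles=1000,
        kernel="HMC",
        inner_kernel_params={
            "step_size": 0.01,
            "integration_steps": 20,
        },
        iterations_to_diagnose=10,
        target_essn=0.5,
        num_mcmc_steps=10,
    )
```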
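For the netCDF failure: xarray only serializes attribute values that are strings, numbers, NumPy arrays or scalars, lists, or tuples, so a dict (such as kernel parameters) or `None` stored in the attributes of `trace` would make `to_netcdf` raise. Which entry of the SMC output actually triggers the error has not been confirmed here. A possible workaround, sketched under that assumption, is to convert unsupported values to strings before saving; `sanitize_attrs` is a hypothetical helper, and the accepted-type tuple mirrors xarray's validation rules as I understand them (newer releases may accept a few more types).

```python
import json
from numbers import Number

import numpy as np

# Attribute value types that xarray accepts when writing netCDF (assumed)
_NETCDF_ATTR_TYPES = (str, Number, np.ndarray, np.number, list, tuple)


def sanitize_attrs(attrs):
    """Hypothetical helper: return a copy of ``attrs`` in which every value
    netCDF cannot store (dicts, None, custom objects, ...) becomes a string."""
    clean = {}
    for key, value in attrs.items():
        if isinstance(value, _NETCDF_ATTR_TYPES):
            clean[key] = value
        else:
            try:
                clean[key] = json.dumps(value)  # e.g. a kernel-parameter dict
            except TypeError:
                clean[key] = str(value)  # fallback for non-JSON-serializable objects
    return clean


# Usage with the `trace` from the snippet above (not executed here):
# for group in trace.groups():
#     ds = getattr(trace, group)
#     ds.attrs = sanitize_attrs(ds.attrs)
# trace.to_netcdf("trace.nc")
```

Storing dicts as JSON strings keeps them human-readable and lets them be recovered with `json.loads` after reading the file back.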
