
NUTS fails to initialize with multi-dim categorical model #3564

@bdyetton


I've just switched to the master branch so I can take advantage of this bug fix: #3535.
However, I'm getting:

    Initializing NUTS failed. Falling back to elementwise auto-assignment.

on some models.
It seems to be caused by giving a categorical a 2-dimensional shape, which makes me think it's related to #3535.

Note the two models below: the top one is fine (besides the divergences), but the bottom one fails.

    import numpy as np
    import pymc3 as pm

    data = np.random.randint(0, 3, size=(1000, 1))

    with pm.Model() as model:
        tp1 = pm.Dirichlet('tp1', a=np.array([0.25] * 4), shape=(4,))
        obs = pm.Categorical('obs', p=tp1, observed=data)
        trace = pm.sample()

which produces:

    Auto-assigning NUTS sampler...
    Initializing NUTS using jitter+adapt_diag...
    Multiprocess sampling (4 chains in 4 jobs)
    NUTS: [tp1]
    Sampling 4 chains, 221 divergences: 100%|| 4000/4000 [00:01<00:00, 3044.22draws/s]
    There were 64 divergences after tuning. Increase `target_accept` or reparameterize.
    There were 49 divergences after tuning. Increase `target_accept` or reparameterize.
    There were 56 divergences after tuning. Increase `target_accept` or reparameterize.
    There were 52 divergences after tuning. Increase `target_accept` or reparameterize.
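
The second model fails:

    data = np.random.randint(0, 3, size=(1000, 2))

    with pm.Model() as model:
        # the difference is the shape here (a 4x4 matrix instead of a length-4 vector)...
        tp1 = pm.Dirichlet('tp1', a=np.array([0.25] * 4), shape=(4, 4))
        # ...and here: each observation uses the row of tp1 selected by data[:, 0]
        obs = pm.Categorical('obs', p=tp1[data[:, 0], :], observed=data[:, 1])
        trace = pm.sample()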

which produces:

    Auto-assigning NUTS sampler...
    Initializing NUTS using jitter+adapt_diag...
    Initializing NUTS failed. Falling back to elementwise auto-assignment.
    Multiprocess sampling (4 chains in 4 jobs)
    Slice: [tp1]
    Sampling 4 chains, 0 divergences: 100%|| 4000/4000 [00:08<00:00, 454.37draws/s]
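For reference, the shapes involved in the failing model can be checked with a NumPy-only sketch. It assumes that `tp1` with `shape=(4, 4)` behaves like four independent Dirichlet(0.25, ..., 0.25) rows, and it uses `numpy.random.Generator.dirichlet` draws as a stand-in for the PyMC3 random variable. It only illustrates the indexing `tp1[data[:, 0], :]`; it does not reproduce the NUTS initialization failure.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for tp1 in the second model (assumption: independent rows):
# each of the 4 rows is a Dirichlet(0.25, 0.25, 0.25, 0.25) draw, so each row sums to 1.
alpha = np.array([0.25] * 4)
tp1 = rng.dirichlet(alpha, size=4)  # shape (4, 4)

# Same layout as the report: column 0 selects a row of tp1, column 1 is the observed category.
data = rng.integers(0, 3, size=(1000, 2))

# Integer-array indexing picks one probability row per observation.
p = tp1[data[:, 0], :]  # shape (1000, 4)

print(tp1.shape, p.shape)
print(np.allclose(p.sum(axis=1), 1.0))
```

So the `p` passed to `pm.Categorical` in the second model is a (1000, 4) matrix of per-observation probability vectors, whereas in the first model it is a single length-4 vector shared by all observations.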
