
BUG: Regression in jax translation from 5.12 -> 5.13 #7263

Open
@velochy

Description


Describe the issue:

I have a few models that require rather complex tensor manipulation, and after upgrading from 5.12 to 5.13 quite a few of them started failing with JAX errors.

As the models themselves are big and unwieldy, I have tried to recreate the same issue with a toy example. As you can see, it needs to be quite convoluted to elicit the error (it requires a model dimension, a call to pt.concatenate, and a call to pt.set_subtensor), but I do run into it in more complex real use cases as well.

I have managed to work around it in some cases by avoiding pt.concatenate and instead creating an empty tensor and setting its parts via pt.set_subtensor, but I have one model where even that runs into issues. So it would be very nice if it worked like it did before :)

The facts of the case:

  • Toy example works with 5.12
  • Toy example fails with 5.13.1
  • Toy example works when using the default PyMC sampler instead of numpyro_nuts

Reproducible code example:

import pymc as pm
from pymc.sampling import jax as pm_jax
import pytensor.tensor as pt
import numpy as np

obs = np.array([
    [1,0,1,0,1,0],
    [0,1,1,0,1,0], 
])
ns = [3,3]

with pm.Model() as model:
    model.add_coord('mw',range(6))
     
    odds = pt.zeros( (len(ns),model.dim_lengths['mw']) )
    modds = pm.Normal('N',shape=(len(ns),model.dim_lengths['mw']//2 - 1))
    modds = pt.concatenate([pt.ones_like(modds[:,:1]),modds[:,:]],axis=1)

    odds = pt.set_subtensor(odds[:,[0,2,4]],modds)

    pm.Multinomial('ov',p=pm.math.softmax(odds), n=ns, observed = obs)

    pm_jax.sample_numpyro_nuts()
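To make the shapes in the reproducer explicit, here is a plain NumPy mirror of its tensor manipulation. `build_odds` is a hypothetical helper that covers only the construction of `odds`, not the Normal prior, the softmax, or the Multinomial likelihood. With `len(ns) == 2` and a `mw` dimension of length 6, `N` has shape `(2, 6 // 2 - 1) = (2, 2)`, the concatenated block has shape `(2, 3)`, and it fills columns 0, 2 and 4 of `odds`.

```python
import numpy as np


def build_odds(n_vals: np.ndarray, n_mw: int = 6) -> np.ndarray:
    """NumPy mirror of the odds construction in the reproducer (illustrative only)."""
    # odds = pt.zeros((len(ns), model.dim_lengths['mw']))
    odds = np.zeros((n_vals.shape[0], n_mw))
    # modds = pt.concatenate([pt.ones_like(modds[:, :1]), modds[:, :]], axis=1)
    modds = np.concatenate([np.ones_like(n_vals[:, :1]), n_vals], axis=1)
    # odds = pt.set_subtensor(odds[:, [0, 2, 4]], modds)
    odds[:, [0, 2, 4]] = modds
    return odds


n_vals = np.array([[0.5, -1.0], [2.0, 3.0]])  # example values for N, shape (2, 2)
odds = build_odds(n_vals)
# odds -> [[1.0, 0.0, 0.5, 0.0, -1.0, 0.0],
#          [1.0, 0.0, 2.0, 0.0,  3.0, 0.0]]
```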

Error message:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/home/velochy/salk/salk_internal_package/experiments.ipynb Cell 1 line 2
     20 #odds = pt.set_subtensor(odds[:,[5,3,1]],modds)
     22 pm.Multinomial('ov',p=pm.math.softmax(odds), n=ns, observed = obs)
---> 24 pm_jax.sample_numpyro_nuts()

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/jax.py:567, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
    564     raise ValueError(f"{nuts_sampler=} not recognized")
    566 tic1 = datetime.now()
--> 567 raw_mcmc_samples, sample_stats, library = sampler_fn(
    568     model=model,
    569     target_accept=target_accept,
    570     tune=tune,
    571     draws=draws,
    572     chains=chains,
    573     chain_method=chain_method,
    574     progressbar=progressbar,
    575     random_seed=random_seed,
    576     initial_points=initial_points,
    577     nuts_kwargs=nuts_kwargs,
    578 )
    579 tic2 = datetime.now()
    581 jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/jax.py:484, in _sample_numpyro_nuts(model, target_accept, tune, draws, chains, chain_method, progressbar, random_seed, initial_points, nuts_kwargs)
    481 if chains > 1:
    482     map_seed = jax.random.split(map_seed, chains)
--> 484 pmap_numpyro.run(
    485     map_seed,
    486     init_params=initial_points,
    487     extra_fields=(
    488         "num_steps",
    489         "potential_energy",
    490         "energy",
    491         "adapt_state.step_size",
    492         "accept_prob",
    493         "diverging",
    494     ),
    495 )
    497 raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
    498 sample_stats = _numpyro_stats_to_dict(pmap_numpyro)

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/mcmc.py:650, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    648     states, last_state = _laxmap(partial_map_fn, map_args)
    649 elif self.chain_method == "parallel":
--> 650     states, last_state = pmap(partial_map_fn)(map_args)
    651 else:
    652     assert self.chain_method == "vectorized"

    [... skipping hidden 12 frame]

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/mcmc.py:426, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
    424 # Check if _sample_fn is None, then we need to initialize the sampler.
    425 if init_state is None or (getattr(self.sampler, "_sample_fn", None) is None):
--> 426     new_init_state = self.sampler.init(
    427         rng_key,
    428         self.num_warmup,
    429         init_params,
    430         model_args=args,
    431         model_kwargs=kwargs,
    432     )
    433     init_state = new_init_state if init_state is None else init_state
    434 sample_fn, postprocess_fn = self._get_cached_fns()

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/hmc.py:783, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    763 hmc_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
    764     init_params,
    765     num_warmup=num_warmup,
   (...)
    780     rng_key=rng_key,
    781 )
    782 if is_prng_key(rng_key):
--> 783     init_state = hmc_init_fn(init_params, rng_key)
    784 else:
    785     # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
    786     # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth,
    787     # wa_steps because those variables do not depend on traced args: init_params, rng_key.
    788     init_state = vmap(hmc_init_fn)(init_params, rng_key)

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/hmc.py:763, in HMC.init.<locals>.<lambda>(init_params, rng_key)
    760         dense_mass = [tuple(sorted(z))] if dense_mass else []
    761     assert isinstance(dense_mass, list)
--> 763 hmc_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
    764     init_params,
    765     num_warmup=num_warmup,
    766     step_size=self._step_size,
    767     num_steps=self._num_steps,
    768     inverse_mass_matrix=inverse_mass_matrix,
    769     adapt_step_size=self._adapt_step_size,
    770     adapt_mass_matrix=self._adapt_mass_matrix,
    771     dense_mass=dense_mass,
    772     target_accept_prob=self._target_accept_prob,
    773     trajectory_length=self._trajectory_length,
    774     max_tree_depth=self._max_tree_depth,
    775     find_heuristic_step_size=self._find_heuristic_step_size,
    776     forward_mode_differentiation=self._forward_mode_differentiation,
    777     regularize_mass_matrix=self._regularize_mass_matrix,
    778     model_args=model_args,
    779     model_kwargs=model_kwargs,
    780     rng_key=rng_key,
    781 )
    782 if is_prng_key(rng_key):
    783     init_state = hmc_init_fn(init_params, rng_key)

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/hmc.py:336, in hmc.<locals>.init_kernel(init_params, num_warmup, step_size, inverse_mass_matrix, adapt_step_size, adapt_mass_matrix, dense_mass, target_accept_prob, num_steps, trajectory_length, max_tree_depth, find_heuristic_step_size, forward_mode_differentiation, regularize_mass_matrix, model_args, model_kwargs, rng_key)
    334 r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum)
    335 vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad)
--> 336 vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
    337 energy = vv_state.potential_energy + kinetic_fn(
    338     wa_state.inverse_mass_matrix, vv_state.r
    339 )
    340 zero_int = jnp.array(0, dtype=jnp.result_type(int))

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/hmc_util.py:282, in velocity_verlet.<locals>.init_fn(z, r, potential_energy, z_grad)
    274 """
    275 :param z: Position of the particle.
    276 :param r: Momentum of the particle.
   (...)
    279 :return: initial state for the integrator.
    280 """
    281 if potential_energy is None or z_grad is None:
--> 282     potential_energy, z_grad = _value_and_grad(
    283         potential_fn, z, forward_mode_differentiation
    284     )
    285 return IntegratorState(z, r, potential_energy, z_grad)

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/hmc_util.py:250, in _value_and_grad(f, x, forward_mode_differentiation)
    248     return out, grads
    249 else:
--> 250     return value_and_grad(f, has_aux=False)(x)

    [... skipping hidden 8 frame]

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/jax.py:156, in get_jaxified_logp.<locals>.logp_fn_wrap(x)
    155 def logp_fn_wrap(x):
--> 156     return logp_fn(*x)[0]

File /tmp/tmpfrmeiqr6:11, in jax_funcified_fgraph(N)
      9 tensor_variable_3 = elemwise_fn_2(tensor_variable_2, tensor_constant_1)
     10 # Alloc([[1.]], 2, Sub.0)
---> 11 tensor_variable_4 = alloc(tensor_constant_2, tensor_constant_3, tensor_variable_3)
     12 # Join(1, Alloc.0, N)
     13 tensor_variable_5 = join(tensor_constant_4, tensor_variable_4, N)

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/pytensor/link/jax/dispatch/tensor_basic.py:47, in jax_funcify_Alloc.<locals>.alloc(x, *shape)
     46 def alloc(x, *shape):
---> 47     res = jnp.broadcast_to(x, shape)
     48     Alloc._check_runtime_broadcast(node, jnp.asarray(x), res.shape)
     49     return res

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:1222, in broadcast_to(array, shape)
   1218 @util.implements(np.broadcast_to, lax_description="""\
   1219 The JAX version does not necessarily return a view of the input.
   1220 """)
   1221 def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array:
-> 1222   return util._broadcast_to(array, shape)

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/jax/_src/numpy/util.py:417, in _broadcast_to(arr, shape)
    415   shape = (shape,)
    416 # check that shape is concrete
--> 417 shape = core.canonicalize_shape(shape)  # type: ignore[arg-type]
    418 arr_shape = np.shape(arr)
    419 if core.definitely_equal_shape(arr_shape, shape):

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/jax/_src/core.py:2117, in canonicalize_shape(shape, context)
   2115 except TypeError:
   2116   pass
-> 2117 raise _invalid_shape_error(shape, context)

TypeError: Shapes must be 1D sequences of concrete values of integer type, got (2, Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function _single_chain_mcmc at /home/velochy/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/mcmc.py:422 for pmap. This value became a tracer due to JAX operations on these lines:

  operation a:bool[] = lt b c
    from line /tmp/tmpfrmeiqr6:5:24 (jax_funcified_fgraph)

  operation a:i64[] = pjit[
  name=_where
  jaxpr={ lambda ; b:bool[] c:i64[] d:i64[]. let
      e:i64[] = select_n b d c
    in (e,) }
] f g h
    from line /tmp/tmpfrmeiqr6:7:24 (jax_funcified_fgraph)

  operation a:i64[] = sub b c
    from line /tmp/tmpfrmeiqr6:9:24 (jax_funcified_fgraph)
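The innermost frames show PyTensor's JAX implementation of `Alloc` passing its shape to `jnp.broadcast_to`, and the generated-code comment `Alloc([[1.]], 2, Sub.0)` indicates that one dimension comes from a `Sub` on a traced value. The standalone sketch below uses plain JAX (no PyMC or PyTensor; the function names are hypothetical) to reproduce the same class of `TypeError`, under the assumption, suggested by the `lt` / `where` / `sub` operations listed in the message, that the width of the ones column is computed with `jnp` operations rather than Python ints. numpyro runs the log-density under `pmap` here; `jax.jit` is used instead and is assumed to trace the function the same way for this purpose. This illustrates the mechanism only and is not a diagnosis of what changed between 5.12 and 5.13.

```python
import jax
import jax.numpy as jnp


def ones_column_traced_width(n):
    # n.shape is static under jit, but routing it through jnp ops
    # (lt / where, as in the error message) stages the result out
    # as a tracer instead of a concrete Python int.
    k = n.shape[1]
    width = jnp.where(jnp.less(1, k), 1, k)
    return jnp.broadcast_to(jnp.ones((1, 1)), (n.shape[0], width))


def ones_column_static_width(n):
    # Same width computed with plain Python arithmetic: it stays concrete.
    k = n.shape[1]
    width = 1 if 1 < k else k
    return jnp.broadcast_to(jnp.ones((1, 1)), (n.shape[0], width))


x = jnp.zeros((2, 2))

# Works: the target shape (2, 1) is made of Python ints
static_out = jax.jit(ones_column_static_width)(x)

# Fails: the second dimension of the target shape is a tracer
try:
    jax.jit(ones_column_traced_width)(x)
    traced_error = None
except TypeError as err:  # JAX reports non-concrete shapes as a TypeError
    traced_error = err
```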

PyMC version information:

Fails on 5.13.1

Context for the issue:

No response
