Description
Describe the issue:
I have a few models that require rather complex tensor manipulation, and when moving from PyMC 5.12 to 5.13 quite a few of them started failing with JAX errors.
As the models themselves are big and unwieldy, I have tried to re-create the same issue with a toy example. As you can see, it needs to be quite convoluted to elicit the error (it requires a model dimension, a call to `pt.concatenate`, and `pt.set_subtensor`), but I do run into it with more complex real use cases as well.
I have managed to work around it in some cases by avoiding `pt.concatenate` and instead creating an empty tensor and setting its parts via `set_subtensor`, but I have one model where even that runs into issues. So it would be very nice if this worked as it did before :)
The facts of the case:
- Toy example works with 5.12
- Toy example fails with 5.13.1
- Toy example works when using the default PyMC sampler instead of `numpyro_nuts`
Reproducible code example:
import pymc as pm
from pymc.sampling import jax as pm_jax
import pytensor.tensor as pt
import numpy as np
# Minimal reproducer. Per the report: works on PyMC 5.12 and with the default
# sampler, but fails under numpyro NUTS (sample_numpyro_nuts) on PyMC 5.13.1.
obs = np.array([
    [1, 0, 1, 0, 1, 0],
    [0, 1, 1, 0, 1, 0],
])
ns = [3, 3]

with pm.Model() as model:
    model.add_coord('mw', range(6))
    # NOTE(review): model.dim_lengths['mw'] appears to be a symbolic length here;
    # the traceback shows a shape derived from it (Sub.0) reaching
    # jnp.broadcast_to as a JAX tracer instead of a concrete int.
    odds = pt.zeros((len(ns), model.dim_lengths['mw']))
    modds = pm.Normal('N', shape=(len(ns), model.dim_lengths['mw'] // 2 - 1))
    # Prepend a column of ones. In the compiled graph this appears as
    # Alloc([[1.]], 2, Sub.0) followed by Join(1, Alloc.0, N), which is where it fails.
    modds = pt.concatenate([pt.ones_like(modds[:, :1]), modds[:, :]], axis=1)
    odds = pt.set_subtensor(odds[:, [0, 2, 4]], modds)
    pm.Multinomial('ov', p=pm.math.softmax(odds), n=ns, observed=obs)
    # Must run inside the model context; this call raises the TypeError shown below.
    pm_jax.sample_numpyro_nuts()
Error message:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/home/velochy/salk/salk_internal_package/experiments.ipynb Cell 1 line 2
20 #odds = pt.set_subtensor(odds[:,[5,3,1]],modds)
22 pm.Multinomial('ov',p=pm.math.softmax(odds), n=ns, observed = obs)
---> 24 pm_jax.sample_numpyro_nuts()
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/jax.py:567, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
564 raise ValueError(f"{nuts_sampler=} not recognized")
566 tic1 = datetime.now()
--> 567 raw_mcmc_samples, sample_stats, library = sampler_fn(
568 model=model,
569 target_accept=target_accept,
570 tune=tune,
571 draws=draws,
572 chains=chains,
573 chain_method=chain_method,
574 progressbar=progressbar,
575 random_seed=random_seed,
576 initial_points=initial_points,
577 nuts_kwargs=nuts_kwargs,
578 )
579 tic2 = datetime.now()
581 jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/jax.py:484, in _sample_numpyro_nuts(model, target_accept, tune, draws, chains, chain_method, progressbar, random_seed, initial_points, nuts_kwargs)
481 if chains > 1:
482 map_seed = jax.random.split(map_seed, chains)
--> 484 pmap_numpyro.run(
485 map_seed,
486 init_params=initial_points,
487 extra_fields=(
488 "num_steps",
489 "potential_energy",
490 "energy",
491 "adapt_state.step_size",
492 "accept_prob",
493 "diverging",
494 ),
495 )
497 raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
498 sample_stats = _numpyro_stats_to_dict(pmap_numpyro)
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/mcmc.py:650, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
648 states, last_state = _laxmap(partial_map_fn, map_args)
649 elif self.chain_method == "parallel":
--> 650 states, last_state = pmap(partial_map_fn)(map_args)
651 else:
652 assert self.chain_method == "vectorized"
[... skipping hidden 12 frame]
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/mcmc.py:426, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
424 # Check if _sample_fn is None, then we need to initialize the sampler.
425 if init_state is None or (getattr(self.sampler, "_sample_fn", None) is None):
--> 426 new_init_state = self.sampler.init(
427 rng_key,
428 self.num_warmup,
429 init_params,
430 model_args=args,
431 model_kwargs=kwargs,
432 )
433 init_state = new_init_state if init_state is None else init_state
434 sample_fn, postprocess_fn = self._get_cached_fns()
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/hmc.py:783, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
763 hmc_init_fn = lambda init_params, rng_key: self._init_fn( # noqa: E731
764 init_params,
765 num_warmup=num_warmup,
(...)
780 rng_key=rng_key,
781 )
782 if is_prng_key(rng_key):
--> 783 init_state = hmc_init_fn(init_params, rng_key)
784 else:
785 # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
786 # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth,
787 # wa_steps because those variables do not depend on traced args: init_params, rng_key.
788 init_state = vmap(hmc_init_fn)(init_params, rng_key)
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/hmc.py:763, in HMC.init.<locals>.<lambda>(init_params, rng_key)
760 dense_mass = [tuple(sorted(z))] if dense_mass else []
761 assert isinstance(dense_mass, list)
--> 763 hmc_init_fn = lambda init_params, rng_key: self._init_fn( # noqa: E731
764 init_params,
765 num_warmup=num_warmup,
766 step_size=self._step_size,
767 num_steps=self._num_steps,
768 inverse_mass_matrix=inverse_mass_matrix,
769 adapt_step_size=self._adapt_step_size,
770 adapt_mass_matrix=self._adapt_mass_matrix,
771 dense_mass=dense_mass,
772 target_accept_prob=self._target_accept_prob,
773 trajectory_length=self._trajectory_length,
774 max_tree_depth=self._max_tree_depth,
775 find_heuristic_step_size=self._find_heuristic_step_size,
776 forward_mode_differentiation=self._forward_mode_differentiation,
777 regularize_mass_matrix=self._regularize_mass_matrix,
778 model_args=model_args,
779 model_kwargs=model_kwargs,
780 rng_key=rng_key,
781 )
782 if is_prng_key(rng_key):
783 init_state = hmc_init_fn(init_params, rng_key)
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/hmc.py:336, in hmc.<locals>.init_kernel(init_params, num_warmup, step_size, inverse_mass_matrix, adapt_step_size, adapt_mass_matrix, dense_mass, target_accept_prob, num_steps, trajectory_length, max_tree_depth, find_heuristic_step_size, forward_mode_differentiation, regularize_mass_matrix, model_args, model_kwargs, rng_key)
334 r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum)
335 vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad)
--> 336 vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
337 energy = vv_state.potential_energy + kinetic_fn(
338 wa_state.inverse_mass_matrix, vv_state.r
339 )
340 zero_int = jnp.array(0, dtype=jnp.result_type(int))
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/hmc_util.py:282, in velocity_verlet.<locals>.init_fn(z, r, potential_energy, z_grad)
274 """
275 :param z: Position of the particle.
276 :param r: Momentum of the particle.
(...)
279 :return: initial state for the integrator.
280 """
281 if potential_energy is None or z_grad is None:
--> 282 potential_energy, z_grad = _value_and_grad(
283 potential_fn, z, forward_mode_differentiation
284 )
285 return IntegratorState(z, r, potential_energy, z_grad)
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/hmc_util.py:250, in _value_and_grad(f, x, forward_mode_differentiation)
248 return out, grads
249 else:
--> 250 return value_and_grad(f, has_aux=False)(x)
[... skipping hidden 8 frame]
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/jax.py:156, in get_jaxified_logp.<locals>.logp_fn_wrap(x)
155 def logp_fn_wrap(x):
--> 156 return logp_fn(*x)[0]
File /tmp/tmpfrmeiqr6:11, in jax_funcified_fgraph(N)
9 tensor_variable_3 = elemwise_fn_2(tensor_variable_2, tensor_constant_1)
10 # Alloc([[1.]], 2, Sub.0)
---> 11 tensor_variable_4 = alloc(tensor_constant_2, tensor_constant_3, tensor_variable_3)
12 # Join(1, Alloc.0, N)
13 tensor_variable_5 = join(tensor_constant_4, tensor_variable_4, N)
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/pytensor/link/jax/dispatch/tensor_basic.py:47, in jax_funcify_Alloc.<locals>.alloc(x, *shape)
46 def alloc(x, *shape):
---> 47 res = jnp.broadcast_to(x, shape)
48 Alloc._check_runtime_broadcast(node, jnp.asarray(x), res.shape)
49 return res
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:1222, in broadcast_to(array, shape)
1218 @util.implements(np.broadcast_to, lax_description="""\
1219 The JAX version does not necessarily return a view of the input.
1220 """)
1221 def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array:
-> 1222 return util._broadcast_to(array, shape)
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/jax/_src/numpy/util.py:417, in _broadcast_to(arr, shape)
415 shape = (shape,)
416 # check that shape is concrete
--> 417 shape = core.canonicalize_shape(shape) # type: ignore[arg-type]
418 arr_shape = np.shape(arr)
419 if core.definitely_equal_shape(arr_shape, shape):
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/jax/_src/core.py:2117, in canonicalize_shape(shape, context)
2115 except TypeError:
2116 pass
-> 2117 raise _invalid_shape_error(shape, context)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (2, Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function _single_chain_mcmc at /home/velochy/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/mcmc.py:422 for pmap. This value became a tracer due to JAX operations on these lines:
operation a:bool[] = lt b c
from line /tmp/tmpfrmeiqr6:5:24 (jax_funcified_fgraph)
operation a:i64[] = pjit[
name=_where
jaxpr={ lambda ; b:bool[] c:i64[] d:i64[]. let
e:i64[] = select_n b d c
in (e,) }
] f g h
from line /tmp/tmpfrmeiqr6:7:24 (jax_funcified_fgraph)
operation a:i64[] = sub b c
from line /tmp/tmpfrmeiqr6:9:24 (jax_funcified_fgraph)
PyMC version information:
Fails on 5.13.1
Context for the issue:
No response