Open
Description
Describe the issue:
Sampling a GP model with `pm.sample(..., compile_kwargs={"mode": "pytorch"})`
fails with a `NotImplementedError`: the PyTensor PyTorch backend has no dispatch
for the scalar `clip` Op. The full error is shown below.
Reproducible code example:
def build_gp_latent_dataset(
    n=200,
    random_seed=42,
    *,
    ell_true=1.0,
    eta_true=4.0,
    sigma_true=1.0,
    nu_true=5.0,
    x_max=10.0,
    jitter=1e-6,
):
    """
    Generate data from a Gaussian Process with Student-T distributed noise.

    This creates a challenging latent variable problem that tests the samplers'
    ability to efficiently explore the high-dimensional posterior over the
    latent GP function values.

    Parameters
    ----------
    n : int
        Number of input locations, evenly spaced on ``[0, x_max]``.
    random_seed : int, optional
        Seed passed to ``np.random.default_rng``. The resulting generator drives
        both the GP prior draw and the Student-T noise.
    ell_true : float
        Lengthscale of the true ExpQuad covariance.
    eta_true : float
        Amplitude of the true covariance (the kernel is ``eta_true**2 * ExpQuad``).
    sigma_true : float
        Scale multiplying the Student-T noise.
    nu_true : float
        Degrees of freedom of the Student-T noise.
    x_max : float
        Upper end of the input grid (the lower end is 0).
    jitter : float
        Value added to the covariance diagonal before sampling the GP draw.

    Returns
    -------
    X : numpy.ndarray
        Input locations as a column vector of shape ``(n, 1)``.
    y : numpy.ndarray
        Noisy observations ``f_true + sigma_true * t`` with ``t ~ StudentT(nu_true)``.
    f_true : numpy.ndarray
        The latent GP function values drawn from the prior.
    """
    rng_local = np.random.default_rng(random_seed)

    # Input locations as a column vector, since the covariance expects 2-D inputs.
    X = np.linspace(0, x_max, n)[:, None]

    # True covariance and mean functions of the generating GP.
    cov_func = eta_true**2 * pm.gp.cov.ExpQuad(1, ell_true)
    mean_func = pm.gp.mean.Zero()

    # Materialize the covariance matrix and add diagonal jitter for numerical stability.
    K = cov_func(X).eval()
    K += jitter * np.eye(n)

    # Sample the latent function values from the GP prior.
    f_true = pm.draw(pm.MvNormal.dist(mu=mean_func(X), cov=K), 1, random_seed=rng_local)

    # Add Student-T distributed noise (heavier tails than Gaussian).
    y = f_true + sigma_true * rng_local.standard_t(df=nu_true, size=n)

    print(f"Generated GP data with {n} points")
    print(f"True hyperparameters: lengthscale={ell_true}, scale={eta_true}")
    print(f"Noise: σ={sigma_true}, ν={nu_true} (Student-T)")
    return X, y, f_true
# Generate the challenging GP dataset
N = 100  # number of data points
X, y_obs, f_true = build_gp_latent_dataset(N)

with pm.Model() as model:
    # Priors on the GP covariance hyperparameters.
    ell = pm.Gamma("ell", alpha=2, beta=1)
    eta = pm.HalfNormal("eta", sigma=5)
    cov = eta**2 * pm.gp.cov.ExpQuad(1, ell)

    # Latent GP over the observed inputs.
    gp = pm.gp.Latent(cov_func=cov)
    f = gp.prior("f", X=X)

    # Student-T likelihood mirroring the generator y = f + sigma * t_nu.
    sigma = pm.HalfNormal("sigma", sigma=2.0)
    # Shift the Gamma draw by 1 so that nu > 1.
    nu = 1 + pm.Gamma("nu", alpha=2, beta=0.1)
    # Pass the scale as `sigma`. StudentT's `lam` is precision-like
    # (lam = 1 / sigma**2), so `lam=1.0/sigma` would make the effective
    # noise scale sqrt(sigma) instead of sigma.
    _ = pm.StudentT("y", mu=f, sigma=sigma, nu=nu, observed=y_obs)

    # Compiling with the PyTorch backend is what raises the
    # "Dispatch not implemented for Scalar Op clip" NotImplementedError.
    idata_pytensor_pytorch = pm.sample(
        draws=500,
        tune=1000,
        chains=4,
        compile_kwargs={"mode": "pytorch"},
        progressbar=False,
    )
Error message:
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
Cell In[20], line 8
6 with TimingContext("PyTensor PyTorch"):
7 with model:
----> 8 idata_pytensor_pytorch = pm.sample(draws=n_draws, tune=n_tune, chains=n_chains, compile_kwargs={"mode": "pytorch"}, progressbar=False)
10 ess_pytensor_pytorch = az.ess(idata_pytensor_pytorch)
11 min_ess = min([ess_pytensor_pytorch[var].values.min() for var in ess_pytensor_pytorch.data_vars])
File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pymc/sampling/mcmc.py:832, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
830 [kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()]
831 with joined_blas_limiter():
--> 832 initial_points, step = init_nuts(
833 init=init,
834 chains=chains,
835 n_init=n_init,
836 model=model,
837 random_seed=random_seed_list,
838 progressbar=progress_bool,
839 jitter_max_retries=jitter_max_retries,
840 tune=tune,
841 initvals=initvals,
842 compile_kwargs=compile_kwargs,
843 **kwargs,
844 )
845 else:
846 # Get initial points
847 ipfns = make_initial_point_fns_per_chain(
848 model=model,
849 overrides=initvals,
850 jitter_rvs=set(),
851 chains=chains,
852 )
File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pymc/sampling/mcmc.py:1598, in init_nuts(init, chains, n_init, model, random_seed, progressbar, jitter_max_retries, tune, initvals, compile_kwargs, **kwargs)
1592 if "advi" in init:
1593 cb = [
1594 pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="absolute"),
1595 pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="relative"),
1596 ]
-> 1598 logp_dlogp_func = model.logp_dlogp_function(ravel_inputs=True, **compile_kwargs)
1599 logp_dlogp_func.trust_input = True
1601 def model_logp_fn(ip: PointType) -> np.ndarray:
File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pymc/model/core.py:572, in Model.logp_dlogp_function(self, grad_vars, tempered, initial_point, ravel_inputs, **kwargs)
566 initial_point = self.initial_point(0)
567 extra_vars_and_values = {
568 var: initial_point[var.name]
569 for var in self.value_vars
570 if var in input_vars and var not in grad_vars
571 }
--> 572 return ValueGradFunction(
573 costs,
574 grad_vars,
575 extra_vars_and_values,
576 model=self,
577 initial_point=initial_point,
578 ravel_inputs=ravel_inputs,
579 **kwargs,
580 )
File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pymc/model/core.py:256, in ValueGradFunction.__init__(self, costs, grad_vars, extra_vars_and_values, dtype, casting, compute_grads, model, initial_point, ravel_inputs, **kwargs)
250 warnings.warn(
251 "ValueGradFunction will become a function of raveled inputs.\n"
252 "Specify `ravel_inputs` to suppress this warning. Note that setting `ravel_inputs=False` will be forbidden in a future release."
253 )
254 inputs = grad_vars
--> 256 self._pytensor_function = compile(inputs, outputs, givens=givens, **kwargs)
257 self._raveled_inputs = ravel_inputs
File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pymc/pytensorf.py:947, in compile(inputs, outputs, random_seed, mode, **kwargs)
945 opt_qry = mode.provided_optimizer.including("random_make_inplace", check_parameter_opt)
946 mode = Mode(linker=mode.linker, optimizer=opt_qry)
--> 947 pytensor_function = pytensor.function(
948 inputs,
949 outputs,
950 updates={**rng_updates, **kwargs.pop("updates", {})},
951 mode=mode,
952 **kwargs,
953 )
954 return pytensor_function
File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pytensor/compile/function/__init__.py:332, in function(inputs, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, trust_input)
321 fn = orig_function(
322 inputs,
323 outputs,
(...) 327 trust_input=trust_input,
328 )
329 else:
330 # note: pfunc will also call orig_function -- orig_function is
331 # a choke point that all compilation must pass through
--> 332 fn = pfunc(
333 params=inputs,
334 outputs=outputs,
335 mode=mode,
336 updates=updates,
337 givens=givens,
338 no_default_updates=no_default_updates,
339 accept_inplace=accept_inplace,
340 name=name,
341 rebuild_strict=rebuild_strict,
342 allow_input_downcast=allow_input_downcast,
343 on_unused_input=on_unused_input,
344 profile=profile,
345 output_keys=output_keys,
346 trust_input=trust_input,
347 )
348 return fn
File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pytensor/compile/function/pfunc.py:466, in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph, trust_input)
452 profile = ProfileStats(message=profile)
454 inputs, cloned_outputs = construct_pfunc_ins_and_outs(
455 params,
456 outputs,
(...) 463 fgraph=fgraph,
464 )
--> 466 return orig_function(
467 inputs,
468 cloned_outputs,
469 mode,
470 accept_inplace=accept_inplace,
471 name=name,
472 profile=profile,
473 on_unused_input=on_unused_input,
474 output_keys=output_keys,
475 fgraph=fgraph,
476 trust_input=trust_input,
477 )
File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pytensor/compile/function/types.py:1833, in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph, trust_input)
1820 m = Maker(
1821 inputs,
1822 outputs,
(...) 1830 trust_input=trust_input,
1831 )
1832 with config.change_flags(compute_test_value="off"):
-> 1833 fn = m.create(defaults)
1834 finally:
1835 if profile and fn:
File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pytensor/compile/function/types.py:1717, in FunctionMaker.create(self, input_storage, storage_map)
1714 start_import_time = pytensor.link.c.cmodule.import_time
1716 with config.change_flags(traceback__limit=config.traceback__compile_limit):
-> 1717 _fn, _i, _o = self.linker.make_thunk(
1718 input_storage=input_storage_lists, storage_map=storage_map
1719 )
1721 end_linker = time.perf_counter()
1723 linker_time = end_linker - start_linker
File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pytensor/link/basic.py:245, in LocalLinker.make_thunk(self, input_storage, output_storage, storage_map, **kwargs)
238 def make_thunk(
239 self,
240 input_storage: Optional["InputStorageType"] = None,
(...) 243 **kwargs,
244 ) -> tuple["BasicThunkType", "InputStorageType", "OutputStorageType"]:
--> 245 return self.make_all(
246 input_storage=input_storage,
247 output_storage=output_storage,
248 storage_map=storage_map,
249 )[:3]
File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pytensor/link/basic.py:695, in JITLinker.make_all(self, input_storage, output_storage, storage_map)
692 for k in storage_map:
693 compute_map[k] = [k.owner is None]
--> 695 thunks, nodes, jit_fn = self.create_jitable_thunk(
696 compute_map, nodes, input_storage, output_storage, storage_map
697 )
699 [fn] = thunks
700 fn.jit_fn = jit_fn
File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pytensor/link/basic.py:647, in JITLinker.create_jitable_thunk(self, compute_map, order, input_storage, output_storage, storage_map)
644 # This is a bit hackish, but we only return one of the output nodes
645 output_nodes = [o.owner for o in self.fgraph.outputs if o.owner is not None][:1]
--> 647 converted_fgraph = self.fgraph_convert(
648 self.fgraph,
649 order=order,
650 input_storage=input_storage,
651 output_storage=output_storage,
652 storage_map=storage_map,
653 )
655 thunk_inputs = self.create_thunk_inputs(storage_map)
656 thunk_outputs = [storage_map[n] for n in self.fgraph.outputs]
File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pytensor/link/pytorch/linker.py:33, in PytorchLinker.fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs)
26 return functor
28 built_kwargs = {
29 "unique_name": generator,
30 "conversion_func": conversion_func_register,
31 **kwargs,
32 }
---> 33 return pytorch_funcify(
34 fgraph, input_storage=input_storage, storage_map=storage_map, **built_kwargs
35 )
File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
908 if not args:
909 raise TypeError(f'{funcname} requires at least '
910 '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)
File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pytensor/link/pytorch/dispatch/basic.py:65, in pytorch_funcify_FunctionGraph(fgraph, node, fgraph_name, conversion_func, **kwargs)
56 @pytorch_funcify.register(FunctionGraph)
57 def pytorch_funcify_FunctionGraph(
58 fgraph,
(...) 62 **kwargs,
63 ):
64 built_kwargs = {"conversion_func": conversion_func, **kwargs}
---> 65 return fgraph_to_python(
66 fgraph,
67 conversion_func,
68 type_conversion_fn=pytorch_typify,
69 fgraph_name=fgraph_name,
70 **built_kwargs,
71 )
File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pytensor/link/utils.py:736, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, unique_name, **kwargs)
734 body_assigns = []
735 for node in order:
--> 736 compiled_func = op_conversion_fn(
737 node.op, node=node, storage_map=storage_map, **kwargs
738 )
740 # Create a local alias with a unique name
741 local_compiled_func_name = unique_name(compiled_func)
File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pytensor/link/pytorch/linker.py:23, in PytorchLinker.fgraph_convert.<locals>.conversion_func_register(*args, **kwargs)
22 def conversion_func_register(*args, **kwargs):
---> 23 functor = pytorch_funcify(*args, **kwargs)
24 name = kwargs["unique_name"](functor)
25 self.gen_functors.append((f"_{name}", functor))
File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
908 if not args:
909 raise TypeError(f'{funcname} requires at least '
910 '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)
File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pytensor/link/pytorch/dispatch/elemwise.py:16, in pytorch_funcify_Elemwise(op, node, **kwargs)
12 @pytorch_funcify.register(Elemwise)
13 def pytorch_funcify_Elemwise(op, node, **kwargs):
14 scalar_op = op.scalar_op
---> 16 base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)
18 def check_special_scipy(func_name):
19 if "scipy." not in func_name:
File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
908 if not args:
909 raise TypeError(f'{funcname} requires at least '
910 '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)
File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pytensor/link/pytorch/dispatch/scalar.py:30, in pytorch_funcify_ScalarOp(op, node, **kwargs)
28 nfunc_spec = getattr(op, "nfunc_spec", None)
29 if nfunc_spec is None:
---> 30 raise NotImplementedError(f"Dispatch not implemented for Scalar Op {op}")
32 func_name = nfunc_spec[0].replace("scipy.", "")
34 if "." in func_name:
NotImplementedError: Dispatch not implemented for Scalar Op clip
PyMC version information:
Python implementation: CPython
Python version : 3.12.10
IPython version : 9.2.0
pytensor : 2.30.3
arviz : 0.21.0
pymc : 5.22.0
numpyro : 0.18.0
blackjax: 0.0.0
nutpie : 0.14.3
pandas : 2.2.3
matplotlib: 3.10.3
numpy : 2.2.6
Context for the issue:
No response