Skip to content

BUG: PyTorch mode sampling failure #1478

Open
@fonnesbeck

Description

@fonnesbeck

Describe the issue:

Sampling a GP model using PyTensor with compile_kwargs={"mode": "pytorch"} results in a NotImplementedError associated with the clip Op. Error below.

Reproducible code example:

def build_gp_latent_dataset(n=200, random_seed=42):
    """
    Generate data from a Gaussian Process with Student-T distributed noise.

    This creates a challenging latent variable problem that tests the samplers'
    ability to efficiently explore the high-dimensional posterior over the
    latent GP function values.

    Parameters
    ----------
    n : int
        Number of input locations / observations to generate.
    random_seed : int
        Seed for the NumPy random generator (used for both the GP draw
        and the observation noise).

    Returns
    -------
    X : ndarray, shape (n, 1)
        Input locations.
    y : ndarray
        Noisy observations (Student-T noise around the latent function).
    f_true : ndarray
        Latent GP function values drawn from the prior.
    """
    rng = np.random.default_rng(random_seed)

    # Evenly spaced 1-D input locations, as a column vector for the GP API.
    X = np.linspace(0, 10, n)[:, None]

    # Ground-truth GP hyperparameters.
    true_lengthscale = 1.0
    true_scale = 4.0

    # True covariance/mean functions used to draw the latent function.
    kernel = true_scale**2 * pm.gp.cov.ExpQuad(1, true_lengthscale)
    zero_mean = pm.gp.mean.Zero()

    # Evaluate the covariance matrix and add diagonal jitter for
    # numerical stability of the multivariate-normal draw.
    cov_matrix = kernel(X).eval() + 1e-6 * np.eye(n)

    f_true = pm.draw(
        pm.MvNormal.dist(mu=zero_mean(X), cov=cov_matrix), 1, random_seed=rng
    )

    # Observation noise with heavier tails than Gaussian (Student-T).
    true_sigma = 1.0
    true_nu = 5.0  # degrees of freedom
    y = f_true + true_sigma * rng.standard_t(df=true_nu, size=n)

    print(f"Generated GP data with {n} points")
    print(f"True hyperparameters: lengthscale={true_lengthscale}, scale={true_scale}")
    print(f"Noise: σ={true_sigma}, ν={true_nu} (Student-T)")

    return X, y, f_true

# Generate the challenging GP dataset
N = 100  # number of data points
X, y_obs, f_true = build_gp_latent_dataset(N)

# Build the latent-GP model and sample with the PyTorch backend.
# This is the reproduction of the reported failure: compiling the
# logp/dlogp function with mode="pytorch" raises NotImplementedError
# for the `clip` scalar Op (see the traceback below).
with pm.Model() as model:
    # Priors over the GP hyperparameters.
    ell = pm.Gamma("ell", alpha=2, beta=1)  
    eta = pm.HalfNormal("eta", sigma=5) 
        
    cov = eta**2 * pm.gp.cov.ExpQuad(1, ell)
    gp = pm.gp.Latent(cov_func=cov)
        
    # Latent GP function values at the observed inputs.
    f = gp.prior("f", X=X)
        
    # Observation-noise scale and Student-T degrees of freedom
    # (shifted by 1 so nu > 1 and the likelihood has a finite mean).
    sigma = pm.HalfNormal("sigma", sigma=2.0)
    nu = 1 + pm.Gamma("nu", alpha=2, beta=0.1) 
        
    # NOTE(review): `lam` is the precision parameter; 1.0/sigma is
    # 1/scale rather than 1/scale**2 — presumably intentional for the
    # repro, but confirm if this model is reused elsewhere.
    _ = pm.StudentT("y", mu=f, lam=1.0/sigma, nu=nu, observed=y_obs)

    # Triggers the error: compile_kwargs routes compilation through the
    # PyTensor PyTorch linker, which lacks a dispatch for `clip`.
    idata_pytensor_pytorch = pm.sample(draws=500, tune=1000, chains=4, compile_kwargs={"mode": "pytorch"}, progressbar=False)

Error message:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[20], line 8
      6 with TimingContext("PyTensor PyTorch"):
      7     with model:
----> 8         idata_pytensor_pytorch = pm.sample(draws=n_draws, tune=n_tune, chains=n_chains, compile_kwargs={"mode": "pytorch"}, progressbar=False)
     10 ess_pytensor_pytorch = az.ess(idata_pytensor_pytorch)
     11 min_ess = min([ess_pytensor_pytorch[var].values.min() for var in ess_pytensor_pytorch.data_vars])

File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pymc/sampling/mcmc.py:832, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
    830         [kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()]
    831     with joined_blas_limiter():
--> 832         initial_points, step = init_nuts(
    833             init=init,
    834             chains=chains,
    835             n_init=n_init,
    836             model=model,
    837             random_seed=random_seed_list,
    838             progressbar=progress_bool,
    839             jitter_max_retries=jitter_max_retries,
    840             tune=tune,
    841             initvals=initvals,
    842             compile_kwargs=compile_kwargs,
    843             **kwargs,
    844         )
    845 else:
    846     # Get initial points
    847     ipfns = make_initial_point_fns_per_chain(
    848         model=model,
    849         overrides=initvals,
    850         jitter_rvs=set(),
    851         chains=chains,
    852     )

File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pymc/sampling/mcmc.py:1598, in init_nuts(init, chains, n_init, model, random_seed, progressbar, jitter_max_retries, tune, initvals, compile_kwargs, **kwargs)
   1592 if "advi" in init:
   1593     cb = [
   1594         pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="absolute"),
   1595         pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="relative"),
   1596     ]
-> 1598 logp_dlogp_func = model.logp_dlogp_function(ravel_inputs=True, **compile_kwargs)
   1599 logp_dlogp_func.trust_input = True
   1601 def model_logp_fn(ip: PointType) -> np.ndarray:

File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pymc/model/core.py:572, in Model.logp_dlogp_function(self, grad_vars, tempered, initial_point, ravel_inputs, **kwargs)
    566     initial_point = self.initial_point(0)
    567 extra_vars_and_values = {
    568     var: initial_point[var.name]
    569     for var in self.value_vars
    570     if var in input_vars and var not in grad_vars
    571 }
--> 572 return ValueGradFunction(
    573     costs,
    574     grad_vars,
    575     extra_vars_and_values,
    576     model=self,
    577     initial_point=initial_point,
    578     ravel_inputs=ravel_inputs,
    579     **kwargs,
    580 )

File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pymc/model/core.py:256, in ValueGradFunction.__init__(self, costs, grad_vars, extra_vars_and_values, dtype, casting, compute_grads, model, initial_point, ravel_inputs, **kwargs)
    250         warnings.warn(
    251             "ValueGradFunction will become a function of raveled inputs.\n"
    252             "Specify `ravel_inputs` to suppress this warning. Note that setting `ravel_inputs=False` will be forbidden in a future release."
    253         )
    254     inputs = grad_vars
--> 256 self._pytensor_function = compile(inputs, outputs, givens=givens, **kwargs)
    257 self._raveled_inputs = ravel_inputs

File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pymc/pytensorf.py:947, in compile(inputs, outputs, random_seed, mode, **kwargs)
    945 opt_qry = mode.provided_optimizer.including("random_make_inplace", check_parameter_opt)
    946 mode = Mode(linker=mode.linker, optimizer=opt_qry)
--> 947 pytensor_function = pytensor.function(
    948     inputs,
    949     outputs,
    950     updates={**rng_updates, **kwargs.pop("updates", {})},
    951     mode=mode,
    952     **kwargs,
    953 )
    954 return pytensor_function

File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pytensor/compile/function/__init__.py:332, in function(inputs, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, trust_input)
    321     fn = orig_function(
    322         inputs,
    323         outputs,
   (...)    327         trust_input=trust_input,
    328     )
    329 else:
    330     # note: pfunc will also call orig_function -- orig_function is
    331     #      a choke point that all compilation must pass through
--> 332     fn = pfunc(
    333         params=inputs,
    334         outputs=outputs,
    335         mode=mode,
    336         updates=updates,
    337         givens=givens,
    338         no_default_updates=no_default_updates,
    339         accept_inplace=accept_inplace,
    340         name=name,
    341         rebuild_strict=rebuild_strict,
    342         allow_input_downcast=allow_input_downcast,
    343         on_unused_input=on_unused_input,
    344         profile=profile,
    345         output_keys=output_keys,
    346         trust_input=trust_input,
    347     )
    348 return fn

File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pytensor/compile/function/pfunc.py:466, in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph, trust_input)
    452     profile = ProfileStats(message=profile)
    454 inputs, cloned_outputs = construct_pfunc_ins_and_outs(
    455     params,
    456     outputs,
   (...)    463     fgraph=fgraph,
    464 )
--> 466 return orig_function(
    467     inputs,
    468     cloned_outputs,
    469     mode,
    470     accept_inplace=accept_inplace,
    471     name=name,
    472     profile=profile,
    473     on_unused_input=on_unused_input,
    474     output_keys=output_keys,
    475     fgraph=fgraph,
    476     trust_input=trust_input,
    477 )

File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pytensor/compile/function/types.py:1833, in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph, trust_input)
   1820     m = Maker(
   1821         inputs,
   1822         outputs,
   (...)   1830         trust_input=trust_input,
   1831     )
   1832     with config.change_flags(compute_test_value="off"):
-> 1833         fn = m.create(defaults)
   1834 finally:
   1835     if profile and fn:

File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pytensor/compile/function/types.py:1717, in FunctionMaker.create(self, input_storage, storage_map)
   1714 start_import_time = pytensor.link.c.cmodule.import_time
   1716 with config.change_flags(traceback__limit=config.traceback__compile_limit):
-> 1717     _fn, _i, _o = self.linker.make_thunk(
   1718         input_storage=input_storage_lists, storage_map=storage_map
   1719     )
   1721 end_linker = time.perf_counter()
   1723 linker_time = end_linker - start_linker

File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pytensor/link/basic.py:245, in LocalLinker.make_thunk(self, input_storage, output_storage, storage_map, **kwargs)
    238 def make_thunk(
    239     self,
    240     input_storage: Optional["InputStorageType"] = None,
   (...)    243     **kwargs,
    244 ) -> tuple["BasicThunkType", "InputStorageType", "OutputStorageType"]:
--> 245     return self.make_all(
    246         input_storage=input_storage,
    247         output_storage=output_storage,
    248         storage_map=storage_map,
    249     )[:3]

File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pytensor/link/basic.py:695, in JITLinker.make_all(self, input_storage, output_storage, storage_map)
    692 for k in storage_map:
    693     compute_map[k] = [k.owner is None]
--> 695 thunks, nodes, jit_fn = self.create_jitable_thunk(
    696     compute_map, nodes, input_storage, output_storage, storage_map
    697 )
    699 [fn] = thunks
    700 fn.jit_fn = jit_fn

File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pytensor/link/basic.py:647, in JITLinker.create_jitable_thunk(self, compute_map, order, input_storage, output_storage, storage_map)
    644 # This is a bit hackish, but we only return one of the output nodes
    645 output_nodes = [o.owner for o in self.fgraph.outputs if o.owner is not None][:1]
--> 647 converted_fgraph = self.fgraph_convert(
    648     self.fgraph,
    649     order=order,
    650     input_storage=input_storage,
    651     output_storage=output_storage,
    652     storage_map=storage_map,
    653 )
    655 thunk_inputs = self.create_thunk_inputs(storage_map)
    656 thunk_outputs = [storage_map[n] for n in self.fgraph.outputs]

File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pytensor/link/pytorch/linker.py:33, in PytorchLinker.fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs)
     26     return functor
     28 built_kwargs = {
     29     "unique_name": generator,
     30     "conversion_func": conversion_func_register,
     31     **kwargs,
     32 }
---> 33 return pytorch_funcify(
     34     fgraph, input_storage=input_storage, storage_map=storage_map, **built_kwargs
     35 )

File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
    908 if not args:
    909     raise TypeError(f'{funcname} requires at least '
    910                     '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)

File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pytensor/link/pytorch/dispatch/basic.py:65, in pytorch_funcify_FunctionGraph(fgraph, node, fgraph_name, conversion_func, **kwargs)
     56 @pytorch_funcify.register(FunctionGraph)
     57 def pytorch_funcify_FunctionGraph(
     58     fgraph,
   (...)     62     **kwargs,
     63 ):
     64     built_kwargs = {"conversion_func": conversion_func, **kwargs}
---> 65     return fgraph_to_python(
     66         fgraph,
     67         conversion_func,
     68         type_conversion_fn=pytorch_typify,
     69         fgraph_name=fgraph_name,
     70         **built_kwargs,
     71     )

File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pytensor/link/utils.py:736, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, unique_name, **kwargs)
    734 body_assigns = []
    735 for node in order:
--> 736     compiled_func = op_conversion_fn(
    737         node.op, node=node, storage_map=storage_map, **kwargs
    738     )
    740     # Create a local alias with a unique name
    741     local_compiled_func_name = unique_name(compiled_func)

File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pytensor/link/pytorch/linker.py:23, in PytorchLinker.fgraph_convert.<locals>.conversion_func_register(*args, **kwargs)
     22 def conversion_func_register(*args, **kwargs):
---> 23     functor = pytorch_funcify(*args, **kwargs)
     24     name = kwargs["unique_name"](functor)
     25     self.gen_functors.append((f"_{name}", functor))

File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
    908 if not args:
    909     raise TypeError(f'{funcname} requires at least '
    910                     '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)

File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pytensor/link/pytorch/dispatch/elemwise.py:16, in pytorch_funcify_Elemwise(op, node, **kwargs)
     12 @pytorch_funcify.register(Elemwise)
     13 def pytorch_funcify_Elemwise(op, node, **kwargs):
     14     scalar_op = op.scalar_op
---> 16     base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)
     18     def check_special_scipy(func_name):
     19         if "scipy." not in func_name:

File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
    908 if not args:
    909     raise TypeError(f'{funcname} requires at least '
    910                     '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)

File ~/repos/pymc-examples/.pixi/envs/default/lib/python3.12/site-packages/pytensor/link/pytorch/dispatch/scalar.py:30, in pytorch_funcify_ScalarOp(op, node, **kwargs)
     28 nfunc_spec = getattr(op, "nfunc_spec", None)
     29 if nfunc_spec is None:
---> 30     raise NotImplementedError(f"Dispatch not implemented for Scalar Op {op}")
     32 func_name = nfunc_spec[0].replace("scipy.", "")
     34 if "." in func_name:

NotImplementedError: Dispatch not implemented for Scalar Op clip

PyMC version information:

Python implementation: CPython
Python version : 3.12.10
IPython version : 9.2.0

pytensor: 2.30.3
arviz : 0.21.0
pymc : 5.22.0
numpyro : 0.18.0
blackjax: 0.0.0
nutpie : 0.14.3

pymc : 5.22.0
pandas : 2.2.3
arviz : 0.21.0
numpyro : 0.18.0
matplotlib: 3.10.3
numpy : 2.2.6

Context for the issue:

No response

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions