diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 93975a27fc6e..e859c6e973d2 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -21,7 +21,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import logging +from ..utils import logging, randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput @@ -89,6 +89,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). + We also support the SDE variant of DPM-Solver and DPM-Solver++, which is a fast SDE solver for the reverse + diffusion SDE. Currently we only support the first-order and second-order solvers. We recommend using the + second-order `sde-dpmsolver++`. + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and @@ -121,10 +125,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++`. algorithm_type (`str`, default `dpmsolver++`): - the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the - algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in - https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided - sampling (e.g. stable-diffusion). + the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++` or `sde-dpmsolver` or + `sde-dpmsolver++`. The `dpmsolver` type implements the algorithms in https://arxiv.org/abs/2206.00927, and + the `dpmsolver++` type implements the algorithms in https://arxiv.org/abs/2211.01095. We recommend to use + `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling (e.g. stable-diffusion). solver_type (`str`, default `midpoint`): the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are @@ -199,7 +203,7 @@ def __init__( self.init_noise_sigma = 1.0 # settings for DPM-Solver - if algorithm_type not in ["dpmsolver", "dpmsolver++"]: + if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: if algorithm_type == "deis": self.register_to_config(algorithm_type="dpmsolver++") else: @@ -390,10 +394,10 @@ def convert_model_output( `torch.FloatTensor`: the converted model output. """ # DPM-Solver++ needs to solve an integral of the data prediction model. - if self.config.algorithm_type == "dpmsolver++": + if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: if self.config.prediction_type == "epsilon": # DPM-Solver and DPM-Solver++ only need the "mean" output. - if self.config.variance_type in ["learned_range"]: + if self.config.variance_type in ["learned", "learned_range"]: model_output = model_output[:, :3] alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] x0_pred = (sample - sigma_t * model_output) / alpha_t @@ -412,33 +416,42 @@ def convert_model_output( x0_pred = self._threshold_sample(x0_pred) return x0_pred + # DPM-Solver needs to solve an integral of the noise prediction model. - elif self.config.algorithm_type == "dpmsolver": + elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: if self.config.prediction_type == "epsilon": # DPM-Solver and DPM-Solver++ only need the "mean" output. - if self.config.variance_type in ["learned_range"]: - model_output = model_output[:, :3] - return model_output + if self.config.variance_type in ["learned", "learned_range"]: + epsilon = model_output[:, :3] + else: + epsilon = model_output elif self.config.prediction_type == "sample": alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] epsilon = (sample - alpha_t * model_output) / sigma_t - return epsilon elif self.config.prediction_type == "v_prediction": alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] epsilon = alpha_t * model_output + sigma_t * sample - return epsilon else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction` for the DPMSolverSinglestepScheduler." ) + if self.config.thresholding: + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * epsilon) / alpha_t + x0_pred = self._threshold_sample(x0_pred) + epsilon = (sample - alpha_t * x0_pred) / sigma_t + + return epsilon + def dpm_solver_first_order_update( self, model_output: torch.FloatTensor, timestep: int, prev_timestep: int, sample: torch.FloatTensor, + noise: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: """ One step for the first-order DPM-Solver (equivalent to DDIM). @@ -463,6 +476,20 @@ def dpm_solver_first_order_update( x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output elif self.config.algorithm_type == "dpmsolver": x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ( + (sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + x_t = ( + (alpha_t / alpha_s) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) return x_t def singlestep_dpm_solver_second_order_update( @@ -471,6 +498,7 @@ def singlestep_dpm_solver_second_order_update( timestep_list: List[int], prev_timestep: int, sample: torch.FloatTensor, + noise: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: """ One step for the second-order singlestep DPM-Solver. @@ -524,6 +552,38 @@ def singlestep_dpm_solver_second_order_update( - (sigma_t * (torch.exp(h) - 1.0)) * D0 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 ) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * (torch.exp(h) - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) return x_t def singlestep_dpm_solver_third_order_update( @@ -604,6 +664,7 @@ def singlestep_dpm_solver_update( prev_timestep: int, sample: torch.FloatTensor, order: int, + noise: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: """ One step for the singlestep DPM-Solver. @@ -622,10 +683,12 @@ def singlestep_dpm_solver_update( `torch.FloatTensor`: the sample tensor at the previous timestep. """ if order == 1: - return self.dpm_solver_first_order_update(model_output_list[-1], timestep_list[-1], prev_timestep, sample) + return self.dpm_solver_first_order_update( + model_output_list[-1], timestep_list[-1], prev_timestep, sample, noise=noise + ) elif order == 2: return self.singlestep_dpm_solver_second_order_update( - model_output_list, timestep_list, prev_timestep, sample + model_output_list, timestep_list, prev_timestep, sample, noise=noise ) elif order == 3: return self.singlestep_dpm_solver_third_order_update( @@ -639,6 +702,7 @@ def step( model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, + generator=None, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ @@ -675,6 +739,13 @@ def step( self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output + if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) + else: + noise = None + order = self.order_list[step_index] # For img2img denoising might start with order>1 which is not possible @@ -688,7 +759,7 @@ def step( timestep_list = [self.timesteps[step_index - i] for i in range(order - 1, 0, -1)] + [timestep] prev_sample = self.singlestep_dpm_solver_update( - self.model_outputs, timestep_list, prev_timestep, self.sample, order + self.model_outputs, timestep_list, prev_timestep, self.sample, order, noise=noise ) if not return_dict: diff --git a/tests/schedulers/test_scheduler_dpm_single.py b/tests/schedulers/test_scheduler_dpm_single.py index 66be3d5d00ad..a40327ca10a7 100644 --- a/tests/schedulers/test_scheduler_dpm_single.py +++ b/tests/schedulers/test_scheduler_dpm_single.py @@ -175,16 +175,20 @@ def test_prediction_type(self): self.check_over_configs(prediction_type=prediction_type) def test_solver_order_and_type(self): - for algorithm_type in ["dpmsolver", "dpmsolver++"]: + for algorithm_type in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: for solver_type in ["midpoint", "heun"]: for order in [1, 2, 3]: for prediction_type in ["epsilon", "sample"]: - self.check_over_configs( - solver_order=order, - solver_type=solver_type, - prediction_type=prediction_type, - algorithm_type=algorithm_type, - ) + if algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + if order == 3: + continue + else: + self.check_over_configs( + solver_order=order, + solver_type=solver_type, + prediction_type=prediction_type, + algorithm_type=algorithm_type, + ) sample = self.full_loop( solver_order=order, solver_type=solver_type,