diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 15008eec0e04..a5eb33d38acd 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -22,6 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import deprecate, logging +from ..utils.torch_utils import randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput @@ -108,11 +109,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++"`. algorithm_type (`str`, defaults to `dpmsolver++`): - Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the - algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type - implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is - recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in - Stable Diffusion. + Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++` or `sde-dpmsolver++`. The `dpmsolver` + type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the + `dpmsolver++` type implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) + paper. It is recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided + sampling like in Stable Diffusion. solver_type (`str`, defaults to `midpoint`): Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. 
@@ -186,7 +187,7 @@ def __init__( self.init_noise_sigma = 1.0 # settings for DPM-Solver - if algorithm_type not in ["dpmsolver", "dpmsolver++"]: + if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver++"]: if algorithm_type == "deis": self.register_to_config(algorithm_type="dpmsolver++") else: @@ -197,7 +198,7 @@ def __init__( else: raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") - if algorithm_type != "dpmsolver++" and final_sigmas_type == "zero": + if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero": raise ValueError( f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please chooose `sigma_min` instead." ) @@ -493,10 +494,10 @@ def convert_model_output( "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) # DPM-Solver++ needs to solve an integral of the data prediction model. - if self.config.algorithm_type == "dpmsolver++": + if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: if self.config.prediction_type == "epsilon": # DPM-Solver and DPM-Solver++ only need the "mean" output. - if self.config.variance_type in ["learned_range"]: + if self.config.variance_type in ["learned", "learned_range"]: model_output = model_output[:, :3] sigma = self.sigmas[self.step_index] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) @@ -517,34 +518,43 @@ def convert_model_output( x0_pred = self._threshold_sample(x0_pred) return x0_pred + # DPM-Solver needs to solve an integral of the noise prediction model. elif self.config.algorithm_type == "dpmsolver": if self.config.prediction_type == "epsilon": # DPM-Solver and DPM-Solver++ only need the "mean" output. 
- if self.config.variance_type in ["learned_range"]: - model_output = model_output[:, :3] - return model_output + if self.config.variance_type in ["learned", "learned_range"]: + epsilon = model_output[:, :3] + else: + epsilon = model_output elif self.config.prediction_type == "sample": sigma = self.sigmas[self.step_index] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) epsilon = (sample - alpha_t * model_output) / sigma_t - return epsilon elif self.config.prediction_type == "v_prediction": sigma = self.sigmas[self.step_index] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) epsilon = alpha_t * model_output + sigma_t * sample - return epsilon else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction` for the DPMSolverSinglestepScheduler." ) + if self.config.thresholding: + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * epsilon) / alpha_t + x0_pred = self._threshold_sample(x0_pred) + epsilon = (sample - alpha_t * x0_pred) / sigma_t + + return epsilon + def dpm_solver_first_order_update( self, model_output: torch.Tensor, *args, sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: """ @@ -594,6 +604,13 @@ def dpm_solver_first_order_update( x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output elif self.config.algorithm_type == "dpmsolver": x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ( + (sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) return x_t def singlestep_dpm_solver_second_order_update( @@ -601,6 +618,7 @@ def singlestep_dpm_solver_second_order_update( model_output_list: List[torch.Tensor], *args, sample: 
torch.Tensor = None, + noise: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: """ @@ -688,6 +706,22 @@ def singlestep_dpm_solver_second_order_update( - (sigma_t * (torch.exp(h) - 1.0)) * D0 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 ) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s1 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s1 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) return x_t def singlestep_dpm_solver_third_order_update( @@ -800,6 +834,7 @@ def singlestep_dpm_solver_update( *args, sample: torch.Tensor = None, order: int = None, + noise: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: """ @@ -848,9 +883,9 @@ def singlestep_dpm_solver_update( ) if order == 1: - return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample) + return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample, noise=noise) elif order == 2: - return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample) + return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample, noise=noise) elif order == 3: return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample) else: @@ -894,6 +929,7 @@ def step( model_output: torch.Tensor, timestep: int, sample: torch.Tensor, + generator=None, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ @@ -929,6 +965,13 @@ def step( self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output + if 
self.config.algorithm_type == "sde-dpmsolver++": + noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) + else: + noise = None + order = self.order_list[self.step_index] # For img2img denoising might start with order>1 which is not possible @@ -940,9 +983,11 @@ def step( if order == 1: self.sample = sample - prev_sample = self.singlestep_dpm_solver_update(self.model_outputs, sample=self.sample, order=order) + prev_sample = self.singlestep_dpm_solver_update( + self.model_outputs, sample=self.sample, order=order, noise=noise + ) - # upon completion increase step index by one + # upon completion increase step index by one self._step_index += 1 if not return_dict: diff --git a/tests/schedulers/test_scheduler_dpm_single.py b/tests/schedulers/test_scheduler_dpm_single.py index ea43c210d650..873eaecd0a5c 100644 --- a/tests/schedulers/test_scheduler_dpm_single.py +++ b/tests/schedulers/test_scheduler_dpm_single.py @@ -194,16 +194,20 @@ def test_prediction_type(self): self.check_over_configs(prediction_type=prediction_type) def test_solver_order_and_type(self): - for algorithm_type in ["dpmsolver", "dpmsolver++"]: + for algorithm_type in ["dpmsolver", "dpmsolver++", "sde-dpmsolver++"]: for solver_type in ["midpoint", "heun"]: for order in [1, 2, 3]: for prediction_type in ["epsilon", "sample"]: - self.check_over_configs( - solver_order=order, - solver_type=solver_type, - prediction_type=prediction_type, - algorithm_type=algorithm_type, - ) + if algorithm_type == "sde-dpmsolver++": + if order == 3: + continue + else: + self.check_over_configs( + solver_order=order, + solver_type=solver_type, + prediction_type=prediction_type, + algorithm_type=algorithm_type, + ) sample = self.full_loop( solver_order=order, solver_type=solver_type,