From dd8c2d9ba11e85b03a3fe6ecd2ede614bae71606 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Tue, 25 Jul 2023 08:10:53 +0530 Subject: [PATCH 01/11] Add the SDE variant of DPM-Solver and DPM-Solver++ to DPM Single Step --- .../scheduling_dpmsolver_singlestep.py | 105 +++++++++++++++--- tests/schedulers/test_scheduler_dpm_single.py | 18 +-- 2 files changed, 99 insertions(+), 24 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 93975a27fc6e..e859c6e973d2 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -21,7 +21,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import logging +from ..utils import logging, randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput @@ -89,6 +89,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). + We also support the SDE variant of DPM-Solver and DPM-Solver++, which is a fast SDE solver for the reverse + diffusion SDE. Currently we only support the first-order and second-order solvers. We recommend using the + second-order `sde-dpmsolver++`. + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and @@ -121,10 +125,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++`. algorithm_type (`str`, default `dpmsolver++`): - the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the - algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in - https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided - sampling (e.g. stable-diffusion). + the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++` or `sde-dpmsolver` or + `sde-dpmsolver++`. The `dpmsolver` type implements the algorithms in https://arxiv.org/abs/2206.00927, and + the `dpmsolver++` type implements the algorithms in https://arxiv.org/abs/2211.01095. We recommend to use + `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling (e.g. stable-diffusion). solver_type (`str`, default `midpoint`): the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are @@ -199,7 +203,7 @@ def __init__( self.init_noise_sigma = 1.0 # settings for DPM-Solver - if algorithm_type not in ["dpmsolver", "dpmsolver++"]: + if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: if algorithm_type == "deis": self.register_to_config(algorithm_type="dpmsolver++") else: @@ -390,10 +394,10 @@ def convert_model_output( `torch.FloatTensor`: the converted model output. """ # DPM-Solver++ needs to solve an integral of the data prediction model. - if self.config.algorithm_type == "dpmsolver++": + if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: if self.config.prediction_type == "epsilon": # DPM-Solver and DPM-Solver++ only need the "mean" output. - if self.config.variance_type in ["learned_range"]: + if self.config.variance_type in ["learned", "learned_range"]: model_output = model_output[:, :3] alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] x0_pred = (sample - sigma_t * model_output) / alpha_t @@ -412,33 +416,42 @@ def convert_model_output( x0_pred = self._threshold_sample(x0_pred) return x0_pred + # DPM-Solver needs to solve an integral of the noise prediction model. - elif self.config.algorithm_type == "dpmsolver": + elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: if self.config.prediction_type == "epsilon": # DPM-Solver and DPM-Solver++ only need the "mean" output. - if self.config.variance_type in ["learned_range"]: - model_output = model_output[:, :3] - return model_output + if self.config.variance_type in ["learned", "learned_range"]: + epsilon = model_output[:, :3] + else: + epsilon = model_output elif self.config.prediction_type == "sample": alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] epsilon = (sample - alpha_t * model_output) / sigma_t - return epsilon elif self.config.prediction_type == "v_prediction": alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] epsilon = alpha_t * model_output + sigma_t * sample - return epsilon else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction` for the DPMSolverSinglestepScheduler." ) + if self.config.thresholding: + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * epsilon) / alpha_t + x0_pred = self._threshold_sample(x0_pred) + epsilon = (sample - alpha_t * x0_pred) / sigma_t + + return epsilon + def dpm_solver_first_order_update( self, model_output: torch.FloatTensor, timestep: int, prev_timestep: int, sample: torch.FloatTensor, + noise: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: """ One step for the first-order DPM-Solver (equivalent to DDIM). @@ -463,6 +476,20 @@ def dpm_solver_first_order_update( x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output elif self.config.algorithm_type == "dpmsolver": x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ( + (sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + x_t = ( + (alpha_t / alpha_s) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) return x_t def singlestep_dpm_solver_second_order_update( @@ -471,6 +498,7 @@ def singlestep_dpm_solver_second_order_update( timestep_list: List[int], prev_timestep: int, sample: torch.FloatTensor, + noise: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: """ One step for the second-order singlestep DPM-Solver. @@ -524,6 +552,38 @@ def singlestep_dpm_solver_second_order_update( - (sigma_t * (torch.exp(h) - 1.0)) * D0 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 ) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * (torch.exp(h) - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) return x_t def singlestep_dpm_solver_third_order_update( @@ -604,6 +664,7 @@ def singlestep_dpm_solver_update( prev_timestep: int, sample: torch.FloatTensor, order: int, + noise: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: """ One step for the singlestep DPM-Solver. @@ -622,10 +683,12 @@ def singlestep_dpm_solver_update( `torch.FloatTensor`: the sample tensor at the previous timestep. """ if order == 1: - return self.dpm_solver_first_order_update(model_output_list[-1], timestep_list[-1], prev_timestep, sample) + return self.dpm_solver_first_order_update( + model_output_list[-1], timestep_list[-1], prev_timestep, sample, noise=noise + ) elif order == 2: return self.singlestep_dpm_solver_second_order_update( - model_output_list, timestep_list, prev_timestep, sample + model_output_list, timestep_list, prev_timestep, sample, noise=noise ) elif order == 3: return self.singlestep_dpm_solver_third_order_update( @@ -639,6 +702,7 @@ def step( model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, + generator=None, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ @@ -675,6 +739,13 @@ def step( self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output + if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) + else: + noise = None + order = self.order_list[step_index] # For img2img denoising might start with order>1 which is not possible @@ -688,7 +759,7 @@ def step( timestep_list = [self.timesteps[step_index - i] for i in range(order - 1, 0, -1)] + [timestep] prev_sample = self.singlestep_dpm_solver_update( - self.model_outputs, timestep_list, prev_timestep, self.sample, order + self.model_outputs, timestep_list, prev_timestep, self.sample, order, noise=noise ) if not return_dict: diff --git a/tests/schedulers/test_scheduler_dpm_single.py b/tests/schedulers/test_scheduler_dpm_single.py index 66be3d5d00ad..a40327ca10a7 100644 --- a/tests/schedulers/test_scheduler_dpm_single.py +++ b/tests/schedulers/test_scheduler_dpm_single.py @@ -175,16 +175,20 @@ def test_prediction_type(self): self.check_over_configs(prediction_type=prediction_type) def test_solver_order_and_type(self): - for algorithm_type in ["dpmsolver", "dpmsolver++"]: + for algorithm_type in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: for solver_type in ["midpoint", "heun"]: for order in [1, 2, 3]: for prediction_type in ["epsilon", "sample"]: - self.check_over_configs( - solver_order=order, - solver_type=solver_type, - prediction_type=prediction_type, - algorithm_type=algorithm_type, - ) + if algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + if order == 3: + continue + else: + self.check_over_configs( + solver_order=order, + solver_type=solver_type, + prediction_type=prediction_type, + algorithm_type=algorithm_type, + ) sample = self.full_loop( solver_order=order, solver_type=solver_type, From 1b03f42b0b28aa76b3201bbabab662804c7d1e67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 24 May 2024 13:07:13 +0300 Subject: [PATCH 02/11] Fix `noise` parameter Co-authored-by: cmdr2 --- .../scheduling_dpmsolver_singlestep.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index e393fa042dc7..5eca3a9ad811 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -108,11 +108,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++"`. algorithm_type (`str`, defaults to `dpmsolver++`): - Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the - algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type - implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is - recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in - Stable Diffusion. + Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++` or `sde-dpmsolver` or `sde-dpmsolver++`. + The `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) + paper, and the `dpmsolver++` type implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) + paper. It is recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling + like in Stable Diffusion. solver_type (`str`, defaults to `midpoint`): Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. @@ -553,8 +553,8 @@ def dpm_solver_first_order_update( model_output: torch.Tensor, *args, sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, **kwargs, - noise: Optional[torch.FloatTensor] = None, ) -> torch.Tensor: """ One step for the first-order DPMSolver (equivalent to DDIM). @@ -624,8 +624,8 @@ def singlestep_dpm_solver_second_order_update( model_output_list: List[torch.Tensor], *args, sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, **kwargs, - noise: Optional[torch.FloatTensor] = None, ) -> torch.Tensor: """ One step for the second-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the @@ -856,8 +856,8 @@ def singlestep_dpm_solver_update( *args, sample: torch.Tensor = None, order: int = None, + noise: Optional[torch.Tensor] = None, **kwargs, - noise: Optional[torch.FloatTensor] = None, ) -> torch.Tensor: """ One step for the singlestep DPMSolver. @@ -1007,7 +1007,7 @@ def step( if order == 1: self.sample = sample - prev_sample = self.singlestep_dpm_solver_update(self.model_outputs, sample=self.sample, order=order) + prev_sample = self.singlestep_dpm_solver_update(self.model_outputs, sample=self.sample, order=order, noise=noise) # upon completion increase step index by one, noise=noise self._step_index += 1 From 30bf51a44cd25a8dab9b9e6c20cc7e21cf41a71c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 24 May 2024 17:00:04 +0300 Subject: [PATCH 03/11] singlestep seems to prefer s1, s2 in 2., 3. orders rather than s0 like in multistep --- .../schedulers/scheduling_dpmsolver_singlestep.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 5eca3a9ad811..6c0de575ffd9 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -716,14 +716,14 @@ def singlestep_dpm_solver_second_order_update( assert noise is not None if self.config.solver_type == "midpoint": x_t = ( - (sigma_t / sigma_s0 * torch.exp(-h)) * sample + (sigma_t / sigma_s1 * torch.exp(-h)) * sample + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise ) elif self.config.solver_type == "heun": x_t = ( - (sigma_t / sigma_s0 * torch.exp(-h)) * sample + (sigma_t / sigma_s1 * torch.exp(-h)) * sample + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise @@ -732,14 +732,14 @@ def singlestep_dpm_solver_second_order_update( assert noise is not None if self.config.solver_type == "midpoint": x_t = ( - (alpha_t / alpha_s0) * sample + (alpha_t / alpha_s1) * sample - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 - (sigma_t * (torch.exp(h) - 1.0)) * D1 + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise ) elif self.config.solver_type == "heun": x_t = ( - (alpha_t / alpha_s0) * sample + (alpha_t / alpha_s1) * sample - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise From a21e518741ac3e343f9be3cacb842ace67d729e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 24 May 2024 17:15:01 +0300 Subject: [PATCH 04/11] refactor: Move randn_tensor import to torch_utils module --- src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 6c0de575ffd9..568c58cafdec 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -21,7 +21,8 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import deprecate, logging, randn_tensor +from ..utils import deprecate, logging +from ..utils.torch_utils import randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput From 894ffb070a78e1e669d95a7d51ad41113f418d47 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 24 May 2024 17:17:29 +0300 Subject: [PATCH 05/11] make style --- .../schedulers/scheduling_dpmsolver_singlestep.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 568c58cafdec..e5f829451455 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -111,9 +111,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): algorithm_type (`str`, defaults to `dpmsolver++`): Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++` or `sde-dpmsolver` or `sde-dpmsolver++`. The `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) - paper, and the `dpmsolver++` type implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) - paper. It is recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling - like in Stable Diffusion. + paper, and the `dpmsolver++` type implements the algorithms in the + [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or + `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. solver_type (`str`, defaults to `midpoint`): Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. @@ -906,9 +906,7 @@ def singlestep_dpm_solver_update( ) if order == 1: - return self.dpm_solver_first_order_update( - model_output_list[-1], sample=sample, noise=noise - ) + return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample, noise=noise) elif order == 2: return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample, noise=noise) elif order == 3: @@ -1008,7 +1006,9 @@ def step( if order == 1: self.sample = sample - prev_sample = self.singlestep_dpm_solver_update(self.model_outputs, sample=self.sample, order=order, noise=noise) + prev_sample = self.singlestep_dpm_solver_update( + self.model_outputs, sample=self.sample, order=order, noise=noise + ) # upon completion increase step index by one, noise=noise self._step_index += 1 From ea7db0d0c1fe83192bb75f36add16503c8575f9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 24 May 2024 18:58:17 +0300 Subject: [PATCH 06/11] Refactor deprecation warning and error handling logic on `algorithm_type` check in DPMSolverSinglestepScheduler class --- .../schedulers/scheduling_dpmsolver_singlestep.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index e5f829451455..3a59efb3d2ba 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -158,9 +158,9 @@ def __init__( lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, ): - if algorithm_type == "dpmsolver": - deprecation_message = "algorithm_type `dpmsolver` is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" - deprecate("algorithm_types=dpmsolver", "1.0.0", deprecation_message) + if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" + deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message) if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -198,7 +198,7 @@ def __init__( else: raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") - if algorithm_type != "dpmsolver++" and final_sigmas_type == "zero": + if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero": raise ValueError( f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please chooose `sigma_min` instead." ) From 4123c8613c5e340a0807cd6f45cdd8a27d6a16ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 24 May 2024 19:08:35 +0300 Subject: [PATCH 07/11] Fix typos --- src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 3a59efb3d2ba..6cda0c3137e2 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -173,7 +173,7 @@ def __init__( # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) @@ -191,12 +191,12 @@ def __init__( if algorithm_type == "deis": self.register_to_config(algorithm_type="dpmsolver++") else: - raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}") if solver_type not in ["midpoint", "heun"]: if solver_type in ["logrho", "bh1", "bh2"]: self.register_to_config(solver_type="midpoint") else: - raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero": raise ValueError( From 0648ee26f6b76b14da73210e128c6aa7fce00674 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 24 May 2024 20:17:28 +0300 Subject: [PATCH 08/11] Revert "Fix typos" This reverts commit 4123c8613c5e340a0807cd6f45cdd8a27d6a16ee. --- src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 6cda0c3137e2..3a59efb3d2ba 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -173,7 +173,7 @@ def __init__( # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: - raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) @@ -191,12 +191,12 @@ def __init__( if algorithm_type == "deis": self.register_to_config(algorithm_type="dpmsolver++") else: - raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}") + raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") if solver_type not in ["midpoint", "heun"]: if solver_type in ["logrho", "bh1", "bh2"]: self.register_to_config(solver_type="midpoint") else: - raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") + raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero": raise ValueError( From d59a5ea0e70b5820cae7b79fdb4ec33030f02c4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 11 Jul 2024 22:03:22 +0300 Subject: [PATCH 09/11] Remove `sde-dpmsolver` --- .../scheduling_dpmsolver_singlestep.py | 43 +++++-------------- 1 file changed, 10 insertions(+), 33 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 6cda0c3137e2..b5e2203af334 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -109,11 +109,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++"`. algorithm_type (`str`, defaults to `dpmsolver++`): - Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++` or `sde-dpmsolver` or `sde-dpmsolver++`. - The `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) - paper, and the `dpmsolver++` type implements the algorithms in the - [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or - `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. + Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++` or `sde-dpmsolver++`. The `dpmsolver` + type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the + `dpmsolver++` type implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) + paper. It is recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided + sampling like in Stable Diffusion. solver_type (`str`, defaults to `midpoint`): Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. @@ -158,9 +158,9 @@ def __init__( lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, ): - if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + if algorithm_type == "dpmsolver": deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" - deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message) + deprecate("algorithm_types dpmsolver", "1.0.0", deprecation_message) if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -187,7 +187,7 @@ def __init__( self.init_noise_sigma = 1.0 # settings for DPM-Solver - if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: + if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver++"]: if algorithm_type == "deis": self.register_to_config(algorithm_type="dpmsolver++") else: @@ -520,7 +520,7 @@ def convert_model_output( return x0_pred # DPM-Solver needs to solve an integral of the noise prediction model. - elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + elif self.config.algorithm_type == "dpmsolver": if self.config.prediction_type == "epsilon": # DPM-Solver and DPM-Solver++ only need the "mean" output. if self.config.variance_type in ["learned", "learned_range"]: @@ -611,13 +611,6 @@ def dpm_solver_first_order_update( + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise ) - elif self.config.algorithm_type == "sde-dpmsolver": - assert noise is not None - x_t = ( - (alpha_t / alpha_s) * sample - - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output - + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise - ) return x_t def singlestep_dpm_solver_second_order_update( @@ -729,22 +722,6 @@ def singlestep_dpm_solver_second_order_update( + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise ) - elif self.config.algorithm_type == "sde-dpmsolver": - assert noise is not None - if self.config.solver_type == "midpoint": - x_t = ( - (alpha_t / alpha_s1) * sample - - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 - - (sigma_t * (torch.exp(h) - 1.0)) * D1 - + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise - ) - elif self.config.solver_type == "heun": - x_t = ( - (alpha_t / alpha_s1) * sample - - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 - - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise - ) return x_t def singlestep_dpm_solver_third_order_update( @@ -988,7 +965,7 @@ def step( self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output - if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + if self.config.algorithm_type == "sde-dpmsolver++": noise = randn_tensor( model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype ) From 890e93fd1f9e90469b33034621c89050eeb8d337 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 11 Jul 2024 22:16:01 +0300 Subject: [PATCH 10/11] chore: Update deprecation message for algorithm_type `dpmsolver` --- src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index b5e2203af334..a5eb33d38acd 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -159,8 +159,8 @@ def __init__( variance_type: Optional[str] = None, ): if algorithm_type == "dpmsolver": - deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" - deprecate("algorithm_types dpmsolver", "1.0.0", deprecation_message) + deprecation_message = "algorithm_type `dpmsolver` is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" + deprecate("algorithm_types=dpmsolver", "1.0.0", deprecation_message) if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) From 937689cddd59e13dc2f6019b547e6cfd413f16bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 11 Jul 2024 22:16:52 +0300 Subject: [PATCH 11/11] Refactor DPMSolverSinglestepSchedulerTest to exclude sde-dpmsolver from algorithm_type loop --- tests/schedulers/test_scheduler_dpm_single.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/schedulers/test_scheduler_dpm_single.py b/tests/schedulers/test_scheduler_dpm_single.py index f5879d25df71..873eaecd0a5c 100644 --- a/tests/schedulers/test_scheduler_dpm_single.py +++ b/tests/schedulers/test_scheduler_dpm_single.py @@ -194,11 +194,11 @@ def test_prediction_type(self): self.check_over_configs(prediction_type=prediction_type) def test_solver_order_and_type(self): - for algorithm_type in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: + for algorithm_type in ["dpmsolver", "dpmsolver++", "sde-dpmsolver++"]: for solver_type in ["midpoint", "heun"]: for order in [1, 2, 3]: for prediction_type in ["epsilon", "sample"]: - if algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + if algorithm_type == "sde-dpmsolver++": if order == 3: continue else: