Add the SDE variant of DPM-Solver and DPM-Solver++ to DPM Single Step #4251
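
This pull request adds two new `algorithm_type` options, `sde-dpmsolver` and `sde-dpmsolver++`, to `DPMSolverSinglestepScheduler`: stochastic first- and second-order updates, plus a `generator` argument on `step()` that supplies the per-step Gaussian noise. A minimal usage sketch follows; it is not part of the diff, the checkpoint id is only illustrative, and it assumes the PR's changes are installed.

import torch
from diffusers import DiffusionPipeline, DPMSolverSinglestepScheduler

# Load any Stable Diffusion checkpoint; the model id here is just an example.
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)

# Swap in the singlestep scheduler and select the recommended second-order SDE variant.
pipe.scheduler = DPMSolverSinglestepScheduler.from_config(
    pipe.scheduler.config, algorithm_type="sde-dpmsolver++", solver_order=2
)
pipe = pipe.to("cuda")

# Stable Diffusion pipelines forward `generator` to schedulers whose step() accepts it,
# which this PR enables here, so the stochastic sampling is reproducible.
generator = torch.Generator(device="cuda").manual_seed(0)
image = pipe("an astronaut riding a horse", num_inference_steps=25, generator=generator).images[0]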

105 changes: 88 additions & 17 deletions src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
@@ -21,7 +21,7 @@
import torch

from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import logging
+from ..utils import logging, randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


@@ -89,6 +89,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as
stable-diffusion).

We also support the SDE variant of DPM-Solver and DPM-Solver++, which is a fast SDE solver for the reverse
diffusion SDE. Currently we only support the first-order and second-order solvers. We recommend using the
second-order `sde-dpmsolver++`.

[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
@@ -121,10 +125,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
the threshold value for dynamic thresholding. Valid only when `thresholding=True` and
`algorithm_type="dpmsolver++`.
algorithm_type (`str`, default `dpmsolver++`):
-the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the
-algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in
-https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided
-sampling (e.g. stable-diffusion).
+the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++` or `sde-dpmsolver` or
+`sde-dpmsolver++`. The `dpmsolver` type implements the algorithms in https://arxiv.org/abs/2206.00927, and
+the `dpmsolver++` type implements the algorithms in https://arxiv.org/abs/2211.01095. We recommend to use
+`dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling (e.g. stable-diffusion).
solver_type (`str`, default `midpoint`):
the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects
the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are
@@ -199,7 +203,7 @@ def __init__(
self.init_noise_sigma = 1.0

# settings for DPM-Solver
-if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
+if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
if algorithm_type == "deis":
self.register_to_config(algorithm_type="dpmsolver++")
else:
@@ -390,10 +394,10 @@ def convert_model_output(
`torch.FloatTensor`: the converted model output.
"""
# DPM-Solver++ needs to solve an integral of the data prediction model.
-if self.config.algorithm_type == "dpmsolver++":
+if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
if self.config.prediction_type == "epsilon":
# DPM-Solver and DPM-Solver++ only need the "mean" output.
-if self.config.variance_type in ["learned_range"]:
+if self.config.variance_type in ["learned", "learned_range"]:
model_output = model_output[:, :3]
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * model_output) / alpha_t
@@ -412,33 +416,42 @@ def convert_model_output(
x0_pred = self._threshold_sample(x0_pred)

return x0_pred

# DPM-Solver needs to solve an integral of the noise prediction model.
-elif self.config.algorithm_type == "dpmsolver":
+elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
if self.config.prediction_type == "epsilon":
# DPM-Solver and DPM-Solver++ only need the "mean" output.
-if self.config.variance_type in ["learned_range"]:
-model_output = model_output[:, :3]
-return model_output
+if self.config.variance_type in ["learned", "learned_range"]:
+epsilon = model_output[:, :3]
+else:
+epsilon = model_output
elif self.config.prediction_type == "sample":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = (sample - alpha_t * model_output) / sigma_t
-return epsilon
elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = alpha_t * model_output + sigma_t * sample
-return epsilon
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction` for the DPMSolverSinglestepScheduler."
)

if self.config.thresholding:
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * epsilon) / alpha_t
x0_pred = self._threshold_sample(x0_pred)
epsilon = (sample - alpha_t * x0_pred) / sigma_t

return epsilon

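For reference (not part of the diff): every branch of convert_model_output() above follows from the parameterization x_t = alpha_t * x0 + sigma_t * eps with alpha_t**2 + sigma_t**2 = 1 and, for v-prediction, v = alpha_t * eps - sigma_t * x0. A small standalone check of those identities:

import torch

alpha_t, sigma_t = torch.tensor(0.8), torch.tensor(0.6)  # any pair with alpha_t**2 + sigma_t**2 == 1
x0, eps = torch.randn(4), torch.randn(4)
x_t = alpha_t * x0 + sigma_t * eps  # noisy sample
v = alpha_t * eps - sigma_t * x0    # v-prediction target

assert torch.allclose((x_t - sigma_t * eps) / alpha_t, x0, atol=1e-6)  # epsilon -> x0 (dpmsolver++ branch)
assert torch.allclose((x_t - alpha_t * x0) / sigma_t, eps, atol=1e-6)  # sample -> epsilon (dpmsolver branch)
assert torch.allclose(alpha_t * x_t - sigma_t * v, x0, atol=1e-6)      # v_prediction -> x0
assert torch.allclose(alpha_t * v + sigma_t * x_t, eps, atol=1e-6)     # v_prediction -> epsilon
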
def dpm_solver_first_order_update(
self,
model_output: torch.FloatTensor,
timestep: int,
prev_timestep: int,
sample: torch.FloatTensor,
noise: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
"""
One step for the first-order DPM-Solver (equivalent to DDIM).
@@ -463,6 +476,20 @@ def dpm_solver_first_order_update(
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
elif self.config.algorithm_type == "dpmsolver":
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
x_t = (
(sigma_t / sigma_s * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.algorithm_type == "sde-dpmsolver":
assert noise is not None
x_t = (
(alpha_t / alpha_s) * sample
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
)
return x_t

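A standalone numerical check (not part of the diff) of the docstring's claim above that the deterministic first-order dpmsolver++ update is equivalent to a DDIM (eta = 0) step; the noise levels are arbitrary values with alpha**2 + sigma**2 = 1:

import torch

alpha_s, sigma_s = torch.tensor(0.6), torch.tensor(0.8)  # current (noisier) level s
alpha_t, sigma_t = torch.tensor(0.8), torch.tensor(0.6)  # target (less noisy) level t
h = torch.log(alpha_t / sigma_t) - torch.log(alpha_s / sigma_s)  # h = lambda_t - lambda_s

sample = torch.randn(4)  # x_s
x0 = torch.randn(4)      # stands in for the converted data prediction

# first-order dpmsolver++ step, exactly as in the branch above
x_dpm = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * x0
# DDIM (eta = 0) step written in terms of the same x0 prediction
eps = (sample - alpha_s * x0) / sigma_s
x_ddim = alpha_t * x0 + sigma_t * eps

assert torch.allclose(x_dpm, x_ddim, atol=1e-6)
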
def singlestep_dpm_solver_second_order_update(
@@ -471,6 +498,7 @@ def singlestep_dpm_solver_second_order_update(
timestep_list: List[int],
prev_timestep: int,
sample: torch.FloatTensor,
noise: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
"""
One step for the second-order singlestep DPM-Solver.
@@ -524,6 +552,38 @@ def singlestep_dpm_solver_second_order_update(
- (sigma_t * (torch.exp(h) - 1.0)) * D0
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
)
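# D0 and D1 (defined in the collapsed portion of this method) are the zeroth- and scaled first-order
# differences of the two most recent buffered model outputs.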
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
if self.config.solver_type == "midpoint":
x_t = (
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+ 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.solver_type == "heun":
x_t = (
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.algorithm_type == "sde-dpmsolver":
assert noise is not None
if self.config.solver_type == "midpoint":
x_t = (
(alpha_t / alpha_s0) * sample
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
- (sigma_t * (torch.exp(h) - 1.0)) * D1
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
)
elif self.config.solver_type == "heun":
x_t = (
(alpha_t / alpha_s0) * sample
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
- 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
)
return x_t

def singlestep_dpm_solver_third_order_update(
@@ -604,6 +664,7 @@ def singlestep_dpm_solver_update(
prev_timestep: int,
sample: torch.FloatTensor,
order: int,
noise: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
"""
One step for the singlestep DPM-Solver.
@@ -622,10 +683,12 @@ def singlestep_dpm_solver_update(
`torch.FloatTensor`: the sample tensor at the previous timestep.
"""
if order == 1:
-return self.dpm_solver_first_order_update(model_output_list[-1], timestep_list[-1], prev_timestep, sample)
+return self.dpm_solver_first_order_update(
+model_output_list[-1], timestep_list[-1], prev_timestep, sample, noise=noise
+)
elif order == 2:
return self.singlestep_dpm_solver_second_order_update(
-model_output_list, timestep_list, prev_timestep, sample
+model_output_list, timestep_list, prev_timestep, sample, noise=noise
)
elif order == 3:
return self.singlestep_dpm_solver_third_order_update(
@@ -639,6 +702,7 @@ def step(
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
generator=None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
@@ -675,6 +739,13 @@ def step(
self.model_outputs[i] = self.model_outputs[i + 1]
self.model_outputs[-1] = model_output
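# the buffer now holds the most recent `solver_order` converted model outputs (newest last; entries may be None during the first steps)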

if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
noise = randn_tensor(
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
)
else:
noise = None

order = self.order_list[step_index]

# For img2img denoising might start with order>1 which is not possible
@@ -688,7 +759,7 @@

timestep_list = [self.timesteps[step_index - i] for i in range(order - 1, 0, -1)] + [timestep]
prev_sample = self.singlestep_dpm_solver_update(
-self.model_outputs, timestep_list, prev_timestep, self.sample, order
+self.model_outputs, timestep_list, prev_timestep, self.sample, order, noise=noise
)

if not return_dict:
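
The new `generator` argument also makes the stochastic step reproducible when the scheduler is driven directly. A standalone sketch, not part of the diff, assuming the PR's changes are installed (shapes, step count, and the random "model output" are arbitrary stand-ins):

import torch
from diffusers import DPMSolverSinglestepScheduler

scheduler = DPMSolverSinglestepScheduler(algorithm_type="sde-dpmsolver++", solver_order=2)
scheduler.set_timesteps(num_inference_steps=20)

sample = torch.randn(1, 4, 64, 64)        # stand-in latent
model_output = torch.randn_like(sample)   # stand-in for a UNet epsilon prediction
t = scheduler.timesteps[0]

out_a = scheduler.step(model_output, t, sample, generator=torch.Generator().manual_seed(0)).prev_sample

# Reset the internal buffers before repeating the same step with the same seed.
scheduler.set_timesteps(num_inference_steps=20)
out_b = scheduler.step(model_output, t, sample, generator=torch.Generator().manual_seed(0)).prev_sample

assert torch.equal(out_a, out_b)  # identical seed -> identical SDE noise -> identical result
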
18 changes: 11 additions & 7 deletions tests/schedulers/test_scheduler_dpm_single.py
@@ -175,16 +175,20 @@ def test_prediction_type(self):
self.check_over_configs(prediction_type=prediction_type)

def test_solver_order_and_type(self):
-for algorithm_type in ["dpmsolver", "dpmsolver++"]:
+for algorithm_type in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
for solver_type in ["midpoint", "heun"]:
for order in [1, 2, 3]:
for prediction_type in ["epsilon", "sample"]:
-self.check_over_configs(
-solver_order=order,
-solver_type=solver_type,
-prediction_type=prediction_type,
-algorithm_type=algorithm_type,
-)
+if algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
+if order == 3:
+continue
+else:
+self.check_over_configs(
+solver_order=order,
+solver_type=solver_type,
+prediction_type=prediction_type,
+algorithm_type=algorithm_type,
+)
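# order == 3 is skipped for the SDE algorithm types above because only first- and second-order SDE updates are implemented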
sample = self.full_loop(
solver_order=order,
solver_type=solver_type,