From d23cd97470964d52c714e553c733f4184c69a6dd Mon Sep 17 00:00:00 2001 From: njindal Date: Tue, 4 Apr 2023 01:59:26 +0530 Subject: [PATCH 1/6] [2905]: Add Karras pattern to discrete euler --- .../schedulers/scheduling_euler_discrete.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index d6252904fd9a..217d815072c2 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -103,6 +103,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): interpolation_type (`str`, default `"linear"`, optional): interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of [`"linear"`, `"log_linear"`]. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Use karras sigmas. For example, specifying `sample_dpmpp_2m` to `set_scheduler` will be equivalent to + `DPM++2M` in stable-diffusion-webui. On top of that, setting this option to True will make it `DPM++2M + Karras`. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -118,6 +122,7 @@ def __init__( trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", interpolation_type: str = "linear", + use_karras_sigmas: Optional[bool] = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -149,6 +154,7 @@ def __init__( timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() self.timesteps = torch.from_numpy(timesteps) self.is_scale_input_called = False + self.use_karras_sigmas = use_karras_sigmas def scale_model_input( self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] @@ -187,6 +193,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) if self.config.interpolation_type == "linear": sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) @@ -198,6 +205,10 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic " 'linear' or 'log_linear'" ) + if self.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas).to(device=device) if str(device).startswith("mps"): @@ -206,6 +217,45 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic else: self.timesteps = torch.from_numpy(timesteps).to(device=device) + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(sigma) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from 
https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17 + def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + print(sigma_min, sigma_max) + + rho = 7.0 + # ramp = torch.linspace(0, 1, self.num_inference_steps) + ramp = np.linspace(0, 1, self.num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + def step( self, model_output: torch.FloatTensor, From 1bf8bdcb2b7759f08e6b9c43b9e6af01b8822fb8 Mon Sep 17 00:00:00 2001 From: njindal Date: Tue, 4 Apr 2023 02:34:08 +0530 Subject: [PATCH 2/6] [2905]: Add Karras pattern to discrete euler --- tests/schedulers/test_scheduler_euler.py | 27 ++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/schedulers/test_scheduler_euler.py b/tests/schedulers/test_scheduler_euler.py index 4d521b0075e1..651fc10c201b 100644 --- a/tests/schedulers/test_scheduler_euler.py +++ b/tests/schedulers/test_scheduler_euler.py @@ -117,3 +117,30 @@ def test_full_loop_device(self): assert abs(result_sum.item() - 10.0807) < 1e-2 assert abs(result_mean.item() - 0.0131) < 1e-3 + + def test_full_loop_device_karras_sigmas(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config, use_karras_sigmas=True) + + scheduler.set_timesteps(self.num_inference_steps, device=torch_device) + + generator = torch.manual_seed(0) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = sample.to(torch_device) + + for t in scheduler.timesteps: + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample, generator=generator) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + print(result_sum.item(), result_mean.item()) + assert abs(result_sum.item() - 124.52299499511719) < 1e-2 + assert abs(result_mean.item() - 0.16213932633399963) < 1e-3 From bb2eebb571326e35c1f6e995b811ecc128a2ac2c Mon Sep 17 00:00:00 2001 From: njindal Date: Tue, 4 Apr 2023 08:08:51 +0530 Subject: [PATCH 3/6] Review comments --- .../schedulers/scheduling_euler_discrete.py | 43 ++----------------- src/diffusers/schedulers/scheduling_utils.py | 24 +++++++++++ tests/schedulers/test_scheduler_euler.py | 2 +- 3 files changed, 28 insertions(+), 41 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 217d815072c2..cf1c4791c5bf 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -18,6 +18,7 @@ import numpy as np import torch +from k_diffusion.sampling import get_sigmas_karras from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, logging, randn_tensor @@ -206,7 +207,8 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic ) if self.use_karras_sigmas: - sigmas = self._convert_to_karras(in_sigmas=sigmas) + # get_sigmas_karras also append 0.0 to the end of the list, so chopping the last element + sigmas = 
get_sigmas_karras(n=num_inference_steps, sigma_min=sigmas.min(), sigma_max=sigmas.max())[:-1] timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) @@ -217,45 +219,6 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic else: self.timesteps = torch.from_numpy(timesteps).to(device=device) - def _sigma_to_t(self, sigma, log_sigmas): - # get log sigma - log_sigma = np.log(sigma) - - # get distribution - dists = log_sigma - log_sigmas[:, np.newaxis] - - # get sigmas range - low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) - high_idx = low_idx + 1 - - low = log_sigmas[low_idx] - high = log_sigmas[high_idx] - - # interpolate sigmas - w = (low - log_sigma) / (low - high) - w = np.clip(w, 0, 1) - - # transform interpolation to time range - t = (1 - w) * low_idx + w * high_idx - t = t.reshape(sigma.shape) - return t - - # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17 - def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor: - """Constructs the noise schedule of Karras et al. (2022).""" - - sigma_min: float = in_sigmas[-1].item() - sigma_max: float = in_sigmas[0].item() - print(sigma_min, sigma_max) - - rho = 7.0 - # ramp = torch.linspace(0, 1, self.num_inference_steps) - ramp = np.linspace(0, 1, self.num_inference_steps) - min_inv_rho = sigma_min ** (1 / rho) - max_inv_rho = sigma_max ** (1 / rho) - sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho - return sigmas - def step( self, model_output: torch.FloatTensor, diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index a4121f75d850..1d9666518e8c 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -18,6 +18,7 @@ from typing import Any, Dict, Optional, Union import torch +import numpy as np from ..utils import BaseOutput @@ -174,3 +175,26 @@ def _get_compatibles(cls): getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c) ] return compatible_classes + @classmethod + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(sigma) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t diff --git a/tests/schedulers/test_scheduler_euler.py b/tests/schedulers/test_scheduler_euler.py index 651fc10c201b..aa46ef31885a 100644 --- a/tests/schedulers/test_scheduler_euler.py +++ b/tests/schedulers/test_scheduler_euler.py @@ -141,6 +141,6 @@ def test_full_loop_device_karras_sigmas(self): result_sum = torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample)) - print(result_sum.item(), result_mean.item()) + assert abs(result_sum.item() - 124.52299499511719) < 1e-2 assert abs(result_mean.item() - 0.16213932633399963) < 1e-3 From 0e09934e458fa01cd6317bd024edbaca0e6d3452 Mon Sep 17 00:00:00 2001 From: njindal Date: Tue, 4 Apr 2023 08:16:17 +0530 Subject: [PATCH 4/6] Review 
comments --- src/diffusers/schedulers/scheduling_euler_discrete.py | 2 +- src/diffusers/schedulers/scheduling_utils.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index cf1c4791c5bf..bb39d620389a 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -107,7 +107,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): use_karras_sigmas (`bool`, *optional*, defaults to `False`): Use karras sigmas. For example, specifying `sample_dpmpp_2m` to `set_scheduler` will be equivalent to `DPM++2M` in stable-diffusion-webui. On top of that, setting this option to True will make it `DPM++2M - Karras`. + Karras`. Please see equation (5) https://arxiv.org/pdf/2206.00364.pdf for more details. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index 1d9666518e8c..e352f0d89b03 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -17,8 +17,8 @@ from enum import Enum from typing import Any, Dict, Optional, Union -import torch import numpy as np +import torch from ..utils import BaseOutput @@ -175,6 +175,7 @@ def _get_compatibles(cls): getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c) ] return compatible_classes + @classmethod def _sigma_to_t(self, sigma, log_sigmas): # get log sigma From 6ea067f62c595d4e437d6f74130e74b051b8af59 Mon Sep 17 00:00:00 2001 From: njindal Date: Tue, 4 Apr 2023 20:33:36 +0530 Subject: [PATCH 5/6] Review comments --- .../schedulers/scheduling_euler_discrete.py | 41 +++++++++++++++++-- src/diffusers/schedulers/scheduling_utils.py | 25 ----------- 2 files changed, 38 insertions(+), 28 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index bb39d620389a..10a4ba1f9fb3 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -18,7 +18,6 @@ import numpy as np import torch -from k_diffusion.sampling import get_sigmas_karras from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, logging, randn_tensor @@ -207,8 +206,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic ) if self.use_karras_sigmas: - # get_sigmas_karras also append 0.0 to the end of the list, so chopping the last element - sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigmas.min(), sigma_max=sigmas.max())[:-1] + sigmas = self._convert_to_karras(in_sigmas=sigmas) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) @@ -219,6 +217,43 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic else: self.timesteps = torch.from_numpy(timesteps).to(device=device) + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(sigma) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = 
(low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17 + def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, self.num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + def step( self, model_output: torch.FloatTensor, diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index e352f0d89b03..a4121f75d850 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -17,7 +17,6 @@ from enum import Enum from typing import Any, Dict, Optional, Union -import numpy as np import torch from ..utils import BaseOutput @@ -175,27 +174,3 @@ def _get_compatibles(cls): getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c) ] return compatible_classes - - @classmethod - def _sigma_to_t(self, sigma, log_sigmas): - # get log sigma - log_sigma = np.log(sigma) - - # get distribution - dists = log_sigma - log_sigmas[:, np.newaxis] - - # get sigmas range - low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) - high_idx = low_idx + 1 - - low = log_sigmas[low_idx] - high = log_sigmas[high_idx] - - # interpolate sigmas - w = (low - log_sigma) / (low - high) - w = np.clip(w, 0, 1) - - # transform interpolation to time range - t = (1 - w) * low_idx + w * high_idx - t = t.reshape(sigma.shape) - return t From fbc1fde9cd85d56a4df77a82de7b9a2724af8e63 Mon Sep 17 00:00:00 2001 From: njindal Date: Thu, 6 Apr 2023 12:45:12 +0530 Subject: [PATCH 6/6] Review comments --- src/diffusers/schedulers/scheduling_euler_discrete.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 10a4ba1f9fb3..df84dd6fd65d 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -104,9 +104,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of [`"linear"`, `"log_linear"`]. use_karras_sigmas (`bool`, *optional*, defaults to `False`): - Use karras sigmas. For example, specifying `sample_dpmpp_2m` to `set_scheduler` will be equivalent to - `DPM++2M` in stable-diffusion-webui. On top of that, setting this option to True will make it `DPM++2M - Karras`. Please see equation (5) https://arxiv.org/pdf/2206.00364.pdf for more details. + This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the + noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence + of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
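
For reference, the two helpers this series converges on in PATCH 5/6 can be exercised standalone. The sketch below mirrors `_convert_to_karras` and `_sigma_to_t` as merged; the toy training schedule (1000 sigmas from 0.01 to 80.0) and the 10-step run are illustrative stand-ins for the scheduler's beta-derived sigmas, not its real values. Note that `set_timesteps` computes `log_sigmas` from the full training-sigma array before any interpolation, so `_sigma_to_t` always maps back against the original schedule:

    import numpy as np

    def convert_to_karras(in_sigmas, num_inference_steps, rho=7.0):
        # Equation (5) of Karras et al. (2022): interpolate linearly between
        # sigma_max and sigma_min in rho-th-root space, then undo the root.
        sigma_min = in_sigmas[-1].item()  # in_sigmas is descending, as in set_timesteps
        sigma_max = in_sigmas[0].item()
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

    def sigma_to_t(sigma, log_sigmas):
        # Invert the training schedule: find the two training sigmas that
        # bracket `sigma` and interpolate their indices in log-sigma space.
        log_sigma = np.log(sigma)
        dists = log_sigma - log_sigmas[:, np.newaxis]
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1
        low, high = log_sigmas[low_idx], log_sigmas[high_idx]
        w = np.clip((low - log_sigma) / (low - high), 0, 1)
        return ((1 - w) * low_idx + w * high_idx).reshape(np.shape(sigma))

    # Toy training schedule: 1000 ascending sigmas (values are illustrative).
    train_sigmas = np.linspace(0.01, 80.0, 1000)
    log_sigmas = np.log(train_sigmas)

    karras_sigmas = convert_to_karras(train_sigmas[::-1], num_inference_steps=10)
    timesteps = np.array([sigma_to_t(s, log_sigmas) for s in karras_sigmas])
    karras_sigmas = np.concatenate([karras_sigmas, [0.0]])  # set_timesteps appends 0.0 last

    print(karras_sigmas)  # coarse near sigma_max, dense near sigma_min
    print(timesteps)      # fractional training timesteps, monotonically decreasing

With rho = 7.0 the spacing concentrates inference steps at low noise levels, which is the behaviour Equation (5) of https://arxiv.org/pdf/2206.00364.pdf is designed to give; the trailing 0.0 sigma is appended only after the timesteps are derived, since log(0) would break the inversion.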
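
And a minimal usage sketch of the new flag, assuming a diffusers build that includes this series; the model id, prompt, and beta values (the usual Stable Diffusion 0.00085/0.012 scaled-linear settings) are illustrative rather than prescribed by the patch:

    from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline

    # Construct the scheduler directly with the new flag.
    scheduler = EulerDiscreteScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        use_karras_sigmas=True,
    )
    scheduler.set_timesteps(25)
    print(scheduler.sigmas)     # Karras-spaced sigmas with the trailing 0.0
    print(scheduler.timesteps)  # fractional timesteps from _sigma_to_t

    # Or swap it into an existing pipeline; kwargs passed to from_config
    # override the values stored in the pipeline's scheduler config.
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
    image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]

The `from_config` override is the usual way to switch an existing pipeline onto the Karras spacing without touching the rest of its scheduler configuration.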