From eb696b99f051e53f15af11dad5b55e11350f7ef6 Mon Sep 17 00:00:00 2001
From: njindal
Date: Thu, 6 Apr 2023 22:04:12 +0530
Subject: [PATCH 1/8] [2737]: Add Karras DPMSolverMultistepScheduler

---
 .../scheduling_dpmsolver_multistep.py         | 50 ++++++++++++++++++-
 1 file changed, 49 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
index 474d9b0d7339..7a08a63de8c3 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -114,7 +114,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         lower_order_final (`bool`, default `True`):
             whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
             find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
-
+        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+            This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
+            noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
+            of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
     """

     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -136,6 +139,7 @@ def __init__(
         algorithm_type: str = "dpmsolver++",
         solver_type: str = "midpoint",
         lower_order_final: bool = True,
+        use_karras_sigmas: Optional[bool] = False,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -181,6 +185,7 @@ def __init__(
         self.timesteps = torch.from_numpy(timesteps)
         self.model_outputs = [None] * solver_order
         self.lower_order_nums = 0
+        self.use_karras_sigmas = use_karras_sigmas

     def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
@@ -199,6 +204,12 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
             .copy()
             .astype(np.int64)
         )
+        if self.use_karras_sigmas:
+            sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+            log_sigmas = np.log(sigmas)
+            sigmas = self._convert_to_karras(in_sigmas=sigmas)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+            timesteps = np.flip(timesteps).copy().astype(np.int64)
         self.timesteps = torch.from_numpy(timesteps).to(device)
         self.model_outputs = [
             None,
@@ -217,6 +228,43 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
         )
         return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val

+    def _sigma_to_t(self, sigma, log_sigmas):
+        # get log sigma
+        log_sigma = np.log(sigma)
+
+        # get distribution
+        dists = log_sigma - log_sigmas[:, np.newaxis]
+
+        # get sigmas range
+        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
+        high_idx = low_idx + 1
+
+        low = log_sigmas[low_idx]
+        high = log_sigmas[high_idx]
+
+        # interpolate sigmas
+        w = (low - log_sigma) / (low - high)
+        w = np.clip(w, 0, 1)
+
+        # transform interpolation to time range
+        t = (1 - w) * low_idx + w * high_idx
+        t = t.reshape(sigma.shape)
+        return t
+
+    # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
+    def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor:
+        """Constructs the noise schedule of Karras et al. (2022)."""
+
+        sigma_min: float = in_sigmas[-1].item()
+        sigma_max: float = in_sigmas[0].item()
+
+        rho = 7.0  # 7.0 is the value used in the paper
+        ramp = np.linspace(0, 1, self.num_inference_steps)
+        min_inv_rho = sigma_min ** (1 / rho)
+        max_inv_rho = sigma_max ** (1 / rho)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+        return sigmas
+
     def convert_model_output(
         self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
     ) -> torch.FloatTensor:
(2022).""" From 31da63b11bb0ca07408556ceff6a3358cc9bec96 Mon Sep 17 00:00:00 2001 From: njindal Date: Fri, 7 Apr 2023 20:18:25 +0530 Subject: [PATCH 3/8] Add test --- tests/schedulers/test_scheduler_dpm_multi.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/schedulers/test_scheduler_dpm_multi.py b/tests/schedulers/test_scheduler_dpm_multi.py index 295bbe882746..59e500606107 100644 --- a/tests/schedulers/test_scheduler_dpm_multi.py +++ b/tests/schedulers/test_scheduler_dpm_multi.py @@ -209,6 +209,12 @@ def test_full_loop_with_v_prediction(self): assert abs(result_mean.item() - 0.2251) < 1e-3 + def test_full_loop_with_karras_and_v_prediction(self): + sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 0.2096) < 1e-3 + def test_switch(self): # make sure that iterating over schedulers with same config names gives same results # for defaults From 2b8ad6e835e13fb0f8ceebc52df252454ef233fa Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 12 Apr 2023 14:18:20 +0530 Subject: [PATCH 4/8] Apply suggestions from code review Co-authored-by: Patrick von Platen --- src/diffusers/schedulers/scheduling_dpmsolver_multistep.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 1320c41231a0..8b0ed74d138d 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -260,7 +260,7 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: return sample - # copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): # get log sigma log_sigma = np.log(sigma) @@ -284,7 +284,7 @@ def _sigma_to_t(self, sigma, log_sigmas): t = t.reshape(sigma.shape) return t - # copied from diffusers.schedulers.scheduling_euler_discrete._convert_to_karras + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor: """Constructs the noise schedule of Karras et al. (2022).""" From 2d9a326a314d71951cad6d8f5bb78e90468fee8d Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 12 Apr 2023 14:37:18 +0530 Subject: [PATCH 5/8] fix: repo consistency. 
From 2d9a326a314d71951cad6d8f5bb78e90468fee8d Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Wed, 12 Apr 2023 14:37:18 +0530
Subject: [PATCH 5/8] fix: repo consistency.

---
 src/diffusers/schedulers/scheduling_dpmsolver_multistep.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
index 8b0ed74d138d..ef0aed55bbf9 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -203,7 +203,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
             .copy()
             .astype(np.int64)
         )
-
+
         if self.use_karras_sigmas:
             sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
             log_sigmas = np.log(sigmas)
@@ -260,7 +260,7 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:

         return sample

-    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras._sigma_to_t
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma, log_sigmas):
         # get log sigma
         log_sigma = np.log(sigma)

From 56cc45aa587934f5324839bb847d6ddd60635276 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Wed, 12 Apr 2023 14:41:11 +0530
Subject: [PATCH 6/8] remove Copied from statement from the set_timesteps method.

---
 src/diffusers/schedulers/scheduling_deis_multistep.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py
index 7aebda205e5b..8ea001a882d0 100644
--- a/src/diffusers/schedulers/scheduling_deis_multistep.py
+++ b/src/diffusers/schedulers/scheduling_deis_multistep.py
@@ -171,7 +171,6 @@ def __init__(
         self.model_outputs = [None] * solver_order
         self.lower_order_nums = 0

-    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_timesteps
     def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
         Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

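Patches 5 and 6 both serve diffusers' repo-consistency tooling: a function annotated with `# Copied from <target>` must match its target verbatim, and `make fix-copies` enforces that. Once DPMSolverMultistepScheduler.set_timesteps grew the Karras branch, the DEIS scheduler's copy of it could no longer carry the marker. A loose conceptual sketch of the check, not the repository's actual utility:

    import inspect

    from diffusers import DEISMultistepScheduler, DPMSolverMultistepScheduler

    def bodies_match(target_fn, copy_fn):
        # Simplified stand-in for the real check: compare the source lines
        # after the `def` line, so only the bodies have to agree verbatim.
        src = inspect.getsource(target_fn).splitlines()[1:]
        dst = inspect.getsource(copy_fn).splitlines()[1:]
        return src == dst

    print(bodies_match(DPMSolverMultistepScheduler.set_timesteps, DEISMultistepScheduler.set_timesteps))
    # -> False once this series lands, which is why patch 6 drops the marker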
From 3995af651dd898409ec24d6e777dc2a7498be597 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Wed, 12 Apr 2023 15:58:50 +0530
Subject: [PATCH 7/8] fix: test

---
 src/diffusers/schedulers/scheduling_dpmsolver_multistep.py | 6 +++---
 src/diffusers/schedulers/scheduling_euler_discrete.py      | 6 +++---
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
index ef0aed55bbf9..3399ee2c54cb 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -207,7 +207,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
         if self.use_karras_sigmas:
             sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
             log_sigmas = np.log(sigmas)
-            sigmas = self._convert_to_karras(in_sigmas=sigmas)
+            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
             timesteps = np.flip(timesteps).copy().astype(np.int64)

@@ -285,14 +285,14 @@ def _sigma_to_t(self, sigma, log_sigmas):
         return t

     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
-    def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor:
+    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
         """Constructs the noise schedule of Karras et al. (2022)."""

         sigma_min: float = in_sigmas[-1].item()
         sigma_max: float = in_sigmas[0].item()

         rho = 7.0  # 7.0 is the value used in the paper
-        ramp = np.linspace(0, 1, self.num_inference_steps)
+        ramp = np.linspace(0, 1, num_inference_steps)
         min_inv_rho = sigma_min ** (1 / rho)
         max_inv_rho = sigma_max ** (1 / rho)
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py
index eea1d14eb4e7..7237128cbf07 100644
--- a/src/diffusers/schedulers/scheduling_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_euler_discrete.py
@@ -206,7 +206,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
         )

         if self.use_karras_sigmas:
-            sigmas = self._convert_to_karras(in_sigmas=sigmas)
+            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
@@ -241,14 +241,14 @@ def _sigma_to_t(self, sigma, log_sigmas):
         return t

     # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
-    def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor:
+    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
         """Constructs the noise schedule of Karras et al. (2022)."""

         sigma_min: float = in_sigmas[-1].item()
         sigma_max: float = in_sigmas[0].item()

         rho = 7.0  # 7.0 is the value used in the paper
-        ramp = np.linspace(0, 1, self.num_inference_steps)
+        ramp = np.linspace(0, 1, num_inference_steps)
         min_inv_rho = sigma_min ** (1 / rho)
         max_inv_rho = sigma_max ** (1 / rho)
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

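Patch 7 threads num_inference_steps through explicitly instead of reading self.num_inference_steps, so _convert_to_karras now works even before that attribute is populated on the scheduler. The remaining helper, _sigma_to_t, maps each Karras sigma back onto the (fractional) training-timestep grid by interpolating in log-sigma space. A standalone NumPy sketch on a toy grid (the sigma values are illustrative):

    import numpy as np

    log_sigmas = np.log(np.array([0.03, 0.1, 0.5, 2.0, 14.6]))  # toy ascending grid

    def sigma_to_t(sigma, log_sigmas):
        log_sigma = np.log(sigma)
        # distance from the query to every grid point, then the bracketing pair
        dists = log_sigma - log_sigmas[:, np.newaxis]
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1
        low, high = log_sigmas[low_idx], log_sigmas[high_idx]
        # linear interpolation weight in log-sigma space, clamped to [0, 1]
        w = np.clip((low - log_sigma) / (low - high), 0, 1)
        return ((1 - w) * low_idx + w * high_idx).reshape(sigma.shape)

    print(sigma_to_t(np.array([0.1, 2.0]), log_sigmas))  # -> [1. 3.]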
(2022).""" sigma_min: float = in_sigmas[-1].item() sigma_max: float = in_sigmas[0].item() rho = 7.0 # 7.0 is the value used in the paper - ramp = np.linspace(0, 1, self.num_inference_steps) + ramp = np.linspace(0, 1, num_inference_steps) min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho From c854dbf7333e09f23ba2724cd6ec8d0edefbea2b Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 12 Apr 2023 15:59:49 +0530 Subject: [PATCH 8/8] Empty commit. Co-authored-by: njindal