From 9653bb9838d4a31c95314d9b5693a2e83b1fceac Mon Sep 17 00:00:00 2001 From: Matthias Humt Date: Wed, 17 Apr 2024 18:50:34 +0200 Subject: [PATCH 01/16] Add VDMScheduler to diffusers/schedulers This commit adds a new scheduler called VDMScheduler to the diffusers/schedulers module. The VDMScheduler is a class that implements a scheduling algorithm for denoising models. It takes into account parameters such as the number of training timesteps, beta schedule, prediction type, timestep spacing, and steps offset. The VDMScheduler class provides methods for scaling model input, adding noise to samples, and stepping through the denoising process. The commit also includes a new file, scheduling_vdm.py, which contains the implementation of the VDMScheduler class. --- src/diffusers/schedulers/__init__.py | 2 + src/diffusers/schedulers/scheduling_vdm.py | 230 +++++++++++++++++++++ 2 files changed, 232 insertions(+) create mode 100644 src/diffusers/schedulers/scheduling_vdm.py diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 720d8ea25e29..2cb401b86538 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -70,6 +70,7 @@ _import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"] _import_structure["scheduling_utils"] = ["KarrasDiffusionSchedulers", "SchedulerMixin"] _import_structure["scheduling_vq_diffusion"] = ["VQDiffusionScheduler"] + _import_structure["scheduling_vdm"] = ["VDMScheduler"] try: if not is_flax_available(): @@ -165,6 +166,7 @@ from .scheduling_unipc_multistep import UniPCMultistepScheduler from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin from .scheduling_vq_diffusion import VQDiffusionScheduler + from .scheduling_vdm import VDMScheduler try: if not is_flax_available(): diff --git a/src/diffusers/schedulers/scheduling_vdm.py b/src/diffusers/schedulers/scheduling_vdm.py new file mode 100644 index 000000000000..04c0c29e3b62 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_vdm.py @@ -0,0 +1,230 @@ +# Copyright 2024 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union +from functools import partial, cache + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from ..utils.torch_utils import randn_tensor +from .scheduling_utils import SchedulerMixin + + +@dataclass +class VDMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None + + +class VDMScheduler(SchedulerMixin, ConfigMixin): + @register_to_config + def __init__(self, + num_train_timesteps: Optional[int] = None, + beta_schedule: str = "linear", + prediction_type: str = "epsilon", + timestep_spacing: str = "leading", + steps_offset: Union[int, float] = 0): + self.beta_start = 1e-4 + self.beta_end = 0.02 + + self.log_snr = partial(self._log_snr, beta_schedule=beta_schedule) + self._log_snr_cached = cache(self.log_snr) + + self.timesteps = torch.from_numpy(self.get_timesteps(num_train_timesteps or 1000)) + self.num_inference_steps = None + # Equivalent to 1 - self.sigmas + # For linear beta schedule equivalent to torch.exp(-1e-4 - 10 * t ** 2) + self.alphas_cumprod = lambda t: torch.sigmoid(self.log_snr(t)) + # Equivalent to 1 - self.alphas_cumprod + self.sigmas = lambda t: torch.sigmoid(-self.log_snr(t)) + + if num_train_timesteps is not None: + self.log_snr = self._log_snr_cached + + # Fixme: Might not be exact + alphas_cumprod = self.alphas_cumprod(torch.flip(self.timesteps, dims=(0,))) + alphas = alphas_cumprod[1:] / alphas_cumprod[:-1] + self.alphas = torch.cat([alphas_cumprod[:1], alphas]) + self.betas = 1 - self.alphas + + def __len__(self) -> int: + return self.num_inference_steps or self.config.num_train_timesteps or len(self.timesteps) + + @staticmethod + def _log_snr(t: torch.FloatTensor, beta_schedule: str) -> torch.FloatTensor: + # From https://github.com/Zhengxinyang/LAS-Diffusion/blob/a7eb304a24dec2eb85a8d3899c73338e10435bba/network/model_utils.py#L345 + if beta_schedule == "linear": + return -torch.log(torch.special.expm1(1e-4 + 10 * t ** 2)) + elif beta_schedule == "squaredcos_cap_v2": + return -torch.log(torch.clamp((torch.cos((t + 0.008) / (1 + 0.008) * math.pi * 0.5) ** -2) - 1, min=1e-5)) + elif beta_schedule == "sigmoid": + return 6 - 12 * t + + raise NotImplementedError(f"{beta_schedule} does is not implemented for {VDMScheduler.__class__}") + + def get_timesteps(self, num_steps: Optional[int] = None) -> np.ndarray: + if num_steps is None: + num_steps = self.config.num_train_timesteps + if self.config.timestep_spacing in ["linspace", "leading"]: + timesteps = np.linspace(0, 1, num_steps, + endpoint=self.config.timestep_spacing == "linspace")[::-1] + elif self.config.time_spacing == "trailing": + timesteps = np.arange(1, 0, -1 / num_steps) - 1 / num_steps + else: + raise ValueError(f"`{self.config.timestep_spacing}` timestep spacing is not supported." + "Choose one of 'linspace', 'leading' or 'trailing'.") + return timesteps.copy() + + def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None): + if self.config.num_train_timesteps is None: + timesteps = self.get_timesteps(num_inference_steps) + else: + if self.config.timestep_spacing in ["linspace", "leading"]: + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, + endpoint=self.config.timestep_spacing == "linspace")[::-1] + elif self.config.timestep_spacing == "trailing": + timesteps = np.arange(self.config.num_train_timesteps, + 0, + -self.config.num_train_timesteps / num_inference_steps) - 1 + else: + raise ValueError(f"`{self.config.timestep_spacing}` timestep spacing is not supported." + "Choose one of 'linspace', 'leading' or 'trailing'.") + timesteps = timesteps.round().astype(np.int64).copy() + + self.num_inference_steps = num_inference_steps + timesteps += self.config.steps_offset + self.timesteps = torch.from_numpy(timesteps).to(device) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.scale_model_input + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + return sample + + def add_noise(self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: Union[torch.FloatTensor, torch.IntTensor, torch.LongTensor]) -> torch.FloatTensor: + + if isinstance(timesteps, (torch.IntTensor, torch.LongTensor)): + if self.config.num_train_timesteps is None: + raise TypeError("Discrete timesteps require `num_train_timesteps` to be set.") + timesteps /= self.config.num_train_timesteps + + log_snr = self.log_snr(timesteps) + log_snr = log_snr.view(timesteps.size(0), *((1,) * (original_samples.ndim - 1))) + + sqrt_alpha_prod = torch.sqrt(torch.sigmoid(log_snr)) + sqrt_one_minus_alpha_prod = torch.sqrt(torch.sigmoid(-log_snr)) # sqrt(sigma) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def step(self, + model_output: torch.FloatTensor, + timestep: Union[int, float, torch.FloatTensor, torch.IntTensor, torch.LongTensor], + sample: torch.FloatTensor, + generator=None, + return_dict: bool = True) -> Union[VDMSchedulerOutput, Tuple]: + + if isinstance(timestep, (int, float)): + timestep = torch.tensor(timestep) + if isinstance(timestep, torch.FloatTensor): + prev_timestep = timestep - 1 / len(self) + elif isinstance(timestep, (torch.IntTensor, torch.LongTensor)): + T = len(self) + prev_timestep = (T - timestep - 1) / T + timestep = (T - timestep) / T + else: + raise TypeError(f"Unsupported type `{type(timestep)}` for `timestep`.") + + # 1. Compute current and previous alpha and sigma values + log_snr = self._log_snr_cached(timestep, beta_schedule=self.config.beta_schedule) + prev_log_snr = self._log_snr_cached(prev_timestep, beta_schedule=self.config.beta_schedule) + + alpha, sigma = torch.sigmoid(log_snr), torch.sigmoid(-log_snr) + prev_alpha, prev_sigma = torch.sigmoid(prev_log_snr), torch.sigmoid(-prev_log_snr) + + # 2. Compute predicted original sample x_0 and predicted previous sample x_{t-1} + c = -torch.expm1(log_snr - prev_log_snr) + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - torch.sqrt(sigma) * model_output) / torch.sqrt(alpha) + pred_prev_sample = torch.sqrt(prev_alpha / alpha) * (sample - c * torch.sqrt(sigma) * model_output) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + pred_prev_sample = torch.sqrt(prev_alpha) * (sample * (1 - c) / torch.sqrt(alpha) + c * model_output) + else: + raise ValueError("`prediction_type` must be one of `epsilon`, `sample` or `v_prediction`.") + + # 3. Add noise + variance = 0 + if timestep > 0: + noise = randn_tensor(model_output.shape, + generator=generator, + device=model_output.device, + dtype=model_output.dtype) + variance = torch.sqrt(prev_sigma * c) * noise + + pred_prev_sample = pred_prev_sample + variance + + if not return_dict: + return (pred_prev_sample,) + + return VDMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) From e5439c3ed2d8034e5b443108ba26ce8b4b895a63 Mon Sep 17 00:00:00 2001 From: Matthias Humt Date: Thu, 18 Apr 2024 09:29:43 +0200 Subject: [PATCH 02/16] Refactor VDMScheduler class in scheduling_vdm.py - Refactored the initialization of self.beta_start and self.beta_end as continuous schedules in self._log_snr are fitted to these values. - Added self.num_inference_steps and self.timesteps attributes. - Updated the implementation of self.alphas_cumprod and self.sigmas. - Modified the implementation of self.log_snr to use the cached version for discrete timesteps and inference. - Updated the implementation of the beta_schedule methods. - Modified the implementation of the forward method to handle discrete timesteps correctly. - Updated the implementation of the forward method to handle different types of timestep inputs. These changes improve the functionality and readability of the VDMScheduler class in scheduling_vdm.py. --- src/diffusers/schedulers/scheduling_vdm.py | 42 ++++++++++++---------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_vdm.py b/src/diffusers/schedulers/scheduling_vdm.py index 04c0c29e3b62..ceea393d6869 100644 --- a/src/diffusers/schedulers/scheduling_vdm.py +++ b/src/diffusers/schedulers/scheduling_vdm.py @@ -54,24 +54,24 @@ def __init__(self, prediction_type: str = "epsilon", timestep_spacing: str = "leading", steps_offset: Union[int, float] = 0): + # Hardcoded as continuous schedules in self._log_snr are fitted to these values self.beta_start = 1e-4 self.beta_end = 0.02 + self.num_inference_steps = None + self.timesteps = torch.from_numpy(self.get_timesteps(num_train_timesteps or 1000)) + self.log_snr = partial(self._log_snr, beta_schedule=beta_schedule) - self._log_snr_cached = cache(self.log_snr) + self._log_snr_cached = cache(self.log_snr) # Cached version for discrete timesteps and inference - self.timesteps = torch.from_numpy(self.get_timesteps(num_train_timesteps or 1000)) - self.num_inference_steps = None - # Equivalent to 1 - self.sigmas # For linear beta schedule equivalent to torch.exp(-1e-4 - 10 * t ** 2) - self.alphas_cumprod = lambda t: torch.sigmoid(self.log_snr(t)) - # Equivalent to 1 - self.alphas_cumprod - self.sigmas = lambda t: torch.sigmoid(-self.log_snr(t)) + self.alphas_cumprod = lambda t: torch.sigmoid(self.log_snr(t)) # Equivalent to 1 - self.sigmas + self.sigmas = lambda t: torch.sigmoid(-self.log_snr(t)) # Equivalent to 1 - self.alphas_cumprod if num_train_timesteps is not None: - self.log_snr = self._log_snr_cached + self.log_snr = self._log_snr_cached # Use cached version for discrete timesteps - # Fixme: Might not be exact + # TODO: Might not be exact alphas_cumprod = self.alphas_cumprod(torch.flip(self.timesteps, dims=(0,))) alphas = alphas_cumprod[1:] / alphas_cumprod[:-1] self.alphas = torch.cat([alphas_cumprod[:1], alphas]) @@ -88,7 +88,10 @@ def _log_snr(t: torch.FloatTensor, beta_schedule: str) -> torch.FloatTensor: elif beta_schedule == "squaredcos_cap_v2": return -torch.log(torch.clamp((torch.cos((t + 0.008) / (1 + 0.008) * math.pi * 0.5) ** -2) - 1, min=1e-5)) elif beta_schedule == "sigmoid": - return 6 - 12 * t + # From https://colab.research.google.com/github/google-research/vdm/blob/main/colab/SimpleDiffusionColab.ipynb + gamma_min = -6 # -13.3 in VDM CIFAR10 experiments + gamma_max = 6 # 5.0 in VDM CIFAR10 experiments + return gamma_max + (gamma_min - gamma_max) * t raise NotImplementedError(f"{beta_schedule} does is not implemented for {VDMScheduler.__class__}") @@ -163,10 +166,10 @@ def add_noise(self, noise: torch.FloatTensor, timesteps: Union[torch.FloatTensor, torch.IntTensor, torch.LongTensor]) -> torch.FloatTensor: - if isinstance(timesteps, (torch.IntTensor, torch.LongTensor)): + if not timesteps.is_floating_point(): if self.config.num_train_timesteps is None: - raise TypeError("Discrete timesteps require `num_train_timesteps` to be set.") - timesteps /= self.config.num_train_timesteps + raise TypeError("Discrete timesteps require `self.config.num_train_timesteps` to be set.") + timesteps = timesteps / self.config.num_train_timesteps log_snr = self.log_snr(timesteps) log_snr = log_snr.view(timesteps.size(0), *((1,) * (original_samples.ndim - 1))) @@ -181,19 +184,20 @@ def step(self, model_output: torch.FloatTensor, timestep: Union[int, float, torch.FloatTensor, torch.IntTensor, torch.LongTensor], sample: torch.FloatTensor, - generator=None, + generator: Optional[torch.Generator] = None, return_dict: bool = True) -> Union[VDMSchedulerOutput, Tuple]: if isinstance(timestep, (int, float)): - timestep = torch.tensor(timestep) - if isinstance(timestep, torch.FloatTensor): + timestep = torch.tensor(timestep, + dtype=torch.float32 if isinstance(timestep, float) else torch.int64, + device=sample.device) + + if timestep.is_floating_point(): prev_timestep = timestep - 1 / len(self) - elif isinstance(timestep, (torch.IntTensor, torch.LongTensor)): + else: T = len(self) prev_timestep = (T - timestep - 1) / T timestep = (T - timestep) / T - else: - raise TypeError(f"Unsupported type `{type(timestep)}` for `timestep`.") # 1. Compute current and previous alpha and sigma values log_snr = self._log_snr_cached(timestep, beta_schedule=self.config.beta_schedule) From 5a9bc6261e7fc19deab0c0c766b579a257e032b1 Mon Sep 17 00:00:00 2001 From: Matthias Humt Date: Thu, 18 Apr 2024 11:04:39 +0200 Subject: [PATCH 03/16] Refactor VDMScheduler class in scheduling_vdm.py - Refactored the `__len__` method to return a default value of 1000 if `num_inference_steps` and `num_train_timesteps` are not set. - Added input validation for `t` in the `_log_snr` method to ensure it is within the range [0, 1]. - Normalized `timesteps` to the range [0, 1] in the `__call__` method if `self.timesteps` is None. - Simplified the calculation of `prev_timestep` in the `__call__` method. - Updated the error message in the `__call__` method to specify that `prediction_type` must be either `epsilon` or `sample`. These changes improve the readability and maintainability of the code, and ensure proper input validation and normalization. --- src/diffusers/schedulers/scheduling_vdm.py | 26 +++++++++++++--------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_vdm.py b/src/diffusers/schedulers/scheduling_vdm.py index ceea393d6869..28534226112e 100644 --- a/src/diffusers/schedulers/scheduling_vdm.py +++ b/src/diffusers/schedulers/scheduling_vdm.py @@ -59,7 +59,7 @@ def __init__(self, self.beta_end = 0.02 self.num_inference_steps = None - self.timesteps = torch.from_numpy(self.get_timesteps(num_train_timesteps or 1000)) + self.timesteps = None self.log_snr = partial(self._log_snr, beta_schedule=beta_schedule) self._log_snr_cached = cache(self.log_snr) # Cached version for discrete timesteps and inference @@ -72,16 +72,20 @@ def __init__(self, self.log_snr = self._log_snr_cached # Use cached version for discrete timesteps # TODO: Might not be exact - alphas_cumprod = self.alphas_cumprod(torch.flip(self.timesteps, dims=(0,))) + timesteps = torch.from_numpy(self.get_timesteps(num_train_timesteps or 1000)) + alphas_cumprod = self.alphas_cumprod(torch.flip(timesteps, dims=(0,))) alphas = alphas_cumprod[1:] / alphas_cumprod[:-1] self.alphas = torch.cat([alphas_cumprod[:1], alphas]) self.betas = 1 - self.alphas def __len__(self) -> int: - return self.num_inference_steps or self.config.num_train_timesteps or len(self.timesteps) + return self.num_inference_steps or self.config.num_train_timesteps or 1000 @staticmethod def _log_snr(t: torch.FloatTensor, beta_schedule: str) -> torch.FloatTensor: + if t.min() < 0 or t.max() > 1: + raise ValueError("`t` must be in range [0, 1].") + # From https://github.com/Zhengxinyang/LAS-Diffusion/blob/a7eb304a24dec2eb85a8d3899c73338e10435bba/network/model_utils.py#L345 if beta_schedule == "linear": return -torch.log(torch.special.expm1(1e-4 + 10 * t ** 2)) @@ -169,7 +173,7 @@ def add_noise(self, if not timesteps.is_floating_point(): if self.config.num_train_timesteps is None: raise TypeError("Discrete timesteps require `self.config.num_train_timesteps` to be set.") - timesteps = timesteps / self.config.num_train_timesteps + timesteps = timesteps / self.config.num_train_timesteps # Normalize to [0, 1] log_snr = self.log_snr(timesteps) log_snr = log_snr.view(timesteps.size(0), *((1,) * (original_samples.ndim - 1))) @@ -192,12 +196,14 @@ def step(self, dtype=torch.float32 if isinstance(timestep, float) else torch.int64, device=sample.device) - if timestep.is_floating_point(): - prev_timestep = timestep - 1 / len(self) + if self.timesteps is None: + if not timestep.is_floating_point(): + timestep = timestep / len(self) + prev_timestep = (timestep - 1 / len(self)).clamp(0, 1) else: - T = len(self) - prev_timestep = (T - timestep - 1) / T - timestep = (T - timestep) / T + # index + 1 corresponds to t - 1 as timesteps are reversed + index = self.index_for_timestep(timestep) + prev_timestep = self.timesteps[index + 1] if index < len(self.timesteps) else 0 # 1. Compute current and previous alpha and sigma values log_snr = self._log_snr_cached(timestep, beta_schedule=self.config.beta_schedule) @@ -215,7 +221,7 @@ def step(self, pred_original_sample = model_output pred_prev_sample = torch.sqrt(prev_alpha) * (sample * (1 - c) / torch.sqrt(alpha) + c * model_output) else: - raise ValueError("`prediction_type` must be one of `epsilon`, `sample` or `v_prediction`.") + raise ValueError("`prediction_type` must be either `epsilon` or `sample`.") # 3. Add noise variance = 0 From cb4019e8f1bae8db08c13c270b36372377b32ff1 Mon Sep 17 00:00:00 2001 From: Matthias Humt Date: Thu, 18 Apr 2024 13:53:21 +0200 Subject: [PATCH 04/16] Refactor VDMScheduler in scheduling_vdm.py - Remove unused imports and variables - Move log_snr function outside of the class - Simplify the calculation of alphas_cumprod and sigmas - Remove redundant code for setting timesteps - Normalize discrete timesteps to [0, 1] - Simplify the computation of alpha, sigma, prev_alpha, and prev_sigma - Remove unnecessary check for timestep type - Simplify the noise calculation - Remove unnecessary variance initialization - Improve code readability and maintainability This commit refactors the VDMScheduler class in scheduling_vdm.py by removing unused imports and variables, moving the log_snr function outside of the class, simplifying the calculation of alphas_cumprod and sigmas, removing redundant code for setting timesteps, normalizing discrete timesteps to [0, 1], simplifying the computation of alpha, sigma, prev_alpha, and prev_sigma, removing unnecessary checks for timestep type, simplifying the noise calculation, removing unnecessary variance initialization, and improving code readability and maintainability. --- src/diffusers/schedulers/scheduling_vdm.py | 98 +++++++++++----------- 1 file changed, 48 insertions(+), 50 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_vdm.py b/src/diffusers/schedulers/scheduling_vdm.py index 28534226112e..835d3347318d 100644 --- a/src/diffusers/schedulers/scheduling_vdm.py +++ b/src/diffusers/schedulers/scheduling_vdm.py @@ -17,7 +17,7 @@ import math from dataclasses import dataclass from typing import List, Optional, Tuple, Union -from functools import partial, cache +from functools import partial import numpy as np import torch @@ -28,6 +28,24 @@ from .scheduling_utils import SchedulerMixin +def log_snr(t: torch.FloatTensor, beta_schedule: str) -> torch.FloatTensor: + if t.min() < 0 or t.max() > 1: + raise ValueError("`t` must be in range [0, 1].") + + # From https://github.com/Zhengxinyang/LAS-Diffusion/blob/a7eb304a24dec2eb85a8d3899c73338e10435bba/network/model_utils.py#L345 + if beta_schedule == "linear": + return -torch.log(torch.special.expm1(1e-4 + 10 * t ** 2)) + elif beta_schedule == "squaredcos_cap_v2": + return -torch.log(torch.clamp((torch.cos((t + 0.008) / (1 + 0.008) * math.pi * 0.5) ** -2) - 1, min=1e-5)) + elif beta_schedule == "sigmoid": + # From https://colab.research.google.com/github/google-research/vdm/blob/main/colab/SimpleDiffusionColab.ipynb + gamma_min = -6 # -13.3 in VDM CIFAR10 experiments + gamma_max = 6 # 5.0 in VDM CIFAR10 experiments + return gamma_max + (gamma_min - gamma_max) * t + + raise NotImplementedError(f"{beta_schedule} does is not implemented for {VDMScheduler.__class__}") + + @dataclass class VDMSchedulerOutput(BaseOutput): """ @@ -57,23 +75,20 @@ def __init__(self, # Hardcoded as continuous schedules in self._log_snr are fitted to these values self.beta_start = 1e-4 self.beta_end = 0.02 + self.init_noise_sigma = 1.0 - self.num_inference_steps = None - self.timesteps = None - - self.log_snr = partial(self._log_snr, beta_schedule=beta_schedule) - self._log_snr_cached = cache(self.log_snr) # Cached version for discrete timesteps and inference + self.log_snr = partial(log_snr, beta_schedule=beta_schedule) # For linear beta schedule equivalent to torch.exp(-1e-4 - 10 * t ** 2) self.alphas_cumprod = lambda t: torch.sigmoid(self.log_snr(t)) # Equivalent to 1 - self.sigmas self.sigmas = lambda t: torch.sigmoid(-self.log_snr(t)) # Equivalent to 1 - self.alphas_cumprod - if num_train_timesteps is not None: - self.log_snr = self._log_snr_cached # Use cached version for discrete timesteps - + self.num_inference_steps = None + self.timesteps = None + if num_train_timesteps: # TODO: Might not be exact - timesteps = torch.from_numpy(self.get_timesteps(num_train_timesteps or 1000)) - alphas_cumprod = self.alphas_cumprod(torch.flip(timesteps, dims=(0,))) + self.timesteps = torch.from_numpy(self.get_timesteps(len(self))) + alphas_cumprod = self.alphas_cumprod(torch.flip(self.timesteps, dims=(0,))) alphas = alphas_cumprod[1:] / alphas_cumprod[:-1] self.alphas = torch.cat([alphas_cumprod[:1], alphas]) self.betas = 1 - self.alphas @@ -81,24 +96,6 @@ def __init__(self, def __len__(self) -> int: return self.num_inference_steps or self.config.num_train_timesteps or 1000 - @staticmethod - def _log_snr(t: torch.FloatTensor, beta_schedule: str) -> torch.FloatTensor: - if t.min() < 0 or t.max() > 1: - raise ValueError("`t` must be in range [0, 1].") - - # From https://github.com/Zhengxinyang/LAS-Diffusion/blob/a7eb304a24dec2eb85a8d3899c73338e10435bba/network/model_utils.py#L345 - if beta_schedule == "linear": - return -torch.log(torch.special.expm1(1e-4 + 10 * t ** 2)) - elif beta_schedule == "squaredcos_cap_v2": - return -torch.log(torch.clamp((torch.cos((t + 0.008) / (1 + 0.008) * math.pi * 0.5) ** -2) - 1, min=1e-5)) - elif beta_schedule == "sigmoid": - # From https://colab.research.google.com/github/google-research/vdm/blob/main/colab/SimpleDiffusionColab.ipynb - gamma_min = -6 # -13.3 in VDM CIFAR10 experiments - gamma_max = 6 # 5.0 in VDM CIFAR10 experiments - return gamma_max + (gamma_min - gamma_max) * t - - raise NotImplementedError(f"{beta_schedule} does is not implemented for {VDMScheduler.__class__}") - def get_timesteps(self, num_steps: Optional[int] = None) -> np.ndarray: if num_steps is None: num_steps = self.config.num_train_timesteps @@ -113,11 +110,15 @@ def get_timesteps(self, num_steps: Optional[int] = None) -> np.ndarray: return timesteps.copy() def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None): - if self.config.num_train_timesteps is None: + if not self.config.num_train_timesteps: timesteps = self.get_timesteps(num_inference_steps) else: if self.config.timestep_spacing in ["linspace", "leading"]: - timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, + start = 0 + stop = self.config.num_train_timesteps + timesteps = np.linspace(start, + stop - 1 if self.config.timestep_spacing == "linspace" else stop, + num_inference_steps, endpoint=self.config.timestep_spacing == "linspace")[::-1] elif self.config.timestep_spacing == "trailing": timesteps = np.arange(self.config.num_train_timesteps, @@ -171,7 +172,7 @@ def add_noise(self, timesteps: Union[torch.FloatTensor, torch.IntTensor, torch.LongTensor]) -> torch.FloatTensor: if not timesteps.is_floating_point(): - if self.config.num_train_timesteps is None: + if not self.config.num_train_timesteps: raise TypeError("Discrete timesteps require `self.config.num_train_timesteps` to be set.") timesteps = timesteps / self.config.num_train_timesteps # Normalize to [0, 1] @@ -196,18 +197,18 @@ def step(self, dtype=torch.float32 if isinstance(timestep, float) else torch.int64, device=sample.device) - if self.timesteps is None: - if not timestep.is_floating_point(): - timestep = timestep / len(self) - prev_timestep = (timestep - 1 / len(self)).clamp(0, 1) - else: - # index + 1 corresponds to t - 1 as timesteps are reversed - index = self.index_for_timestep(timestep) - prev_timestep = self.timesteps[index + 1] if index < len(self.timesteps) else 0 + if not timestep.is_floating_point(): + if not self.config.num_train_timesteps: + raise TypeError("Discrete timesteps require `self.config.num_train_timesteps` to be set.") + timestep = timestep / self.config.num_train_timesteps # Normalize to [0, 1] + prev_timestep = (timestep - 1 / len(self)).clamp(0, 1) + + if prev_timestep > timestep: + raise ValueError("`self.timesteps` must be in descending order.") # 1. Compute current and previous alpha and sigma values - log_snr = self._log_snr_cached(timestep, beta_schedule=self.config.beta_schedule) - prev_log_snr = self._log_snr_cached(prev_timestep, beta_schedule=self.config.beta_schedule) + log_snr = self.log_snr(timestep) + prev_log_snr = self.log_snr(prev_timestep) alpha, sigma = torch.sigmoid(log_snr), torch.sigmoid(-log_snr) prev_alpha, prev_sigma = torch.sigmoid(prev_log_snr), torch.sigmoid(-prev_log_snr) @@ -224,14 +225,11 @@ def step(self, raise ValueError("`prediction_type` must be either `epsilon` or `sample`.") # 3. Add noise - variance = 0 - if timestep > 0: - noise = randn_tensor(model_output.shape, - generator=generator, - device=model_output.device, - dtype=model_output.dtype) - variance = torch.sqrt(prev_sigma * c) * noise - + noise = randn_tensor(model_output.shape, + generator=generator, + device=model_output.device, + dtype=model_output.dtype) + variance = torch.sqrt(prev_sigma * c) * noise pred_prev_sample = pred_prev_sample + variance if not return_dict: From 02b71ef8e5c4a2b32cbbac56ffbc345055709f23 Mon Sep 17 00:00:00 2001 From: Matthias Humt Date: Thu, 18 Apr 2024 14:45:59 +0200 Subject: [PATCH 05/16] Refactor VDMScheduler in scheduling_vdm.py - Move the `log_snr` function from the class to a separate method. - Add type annotations to the `log_snr` method. - Remove the `index_for_timestep` method, as it is copied from another class. - Remove the normalization of timesteps in the `scale_model_input` method, as it is already done in the `log_snr` method. These changes improve code organization and maintainability. --- src/diffusers/schedulers/scheduling_vdm.py | 32 ++++++---------------- 1 file changed, 9 insertions(+), 23 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_vdm.py b/src/diffusers/schedulers/scheduling_vdm.py index 835d3347318d..bac10663c8ce 100644 --- a/src/diffusers/schedulers/scheduling_vdm.py +++ b/src/diffusers/schedulers/scheduling_vdm.py @@ -77,8 +77,6 @@ def __init__(self, self.beta_end = 0.02 self.init_noise_sigma = 1.0 - self.log_snr = partial(log_snr, beta_schedule=beta_schedule) - # For linear beta schedule equivalent to torch.exp(-1e-4 - 10 * t ** 2) self.alphas_cumprod = lambda t: torch.sigmoid(self.log_snr(t)) # Equivalent to 1 - self.sigmas self.sigmas = lambda t: torch.sigmoid(-self.log_snr(t)) # Equivalent to 1 - self.alphas_cumprod @@ -96,6 +94,14 @@ def __init__(self, def __len__(self) -> int: return self.num_inference_steps or self.config.num_train_timesteps or 1000 + def log_snr(self, timesteps: torch.Tensor) -> torch.FloatTensor: + if not timesteps.is_floating_point(): + if not self.config.num_train_timesteps: + raise TypeError("Discrete timesteps require `self.config.num_train_timesteps` to be set.") + timesteps = timesteps / self.config.num_train_timesteps # Normalize to [0, 1] + + return log_snr(timesteps, beta_schedule=self.config.beta_schedule) + def get_timesteps(self, num_steps: Optional[int] = None) -> np.ndarray: if num_steps is None: num_steps = self.config.num_train_timesteps @@ -133,21 +139,6 @@ def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, to timesteps += self.config.steps_offset self.timesteps = torch.from_numpy(timesteps).to(device) - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep - def index_for_timestep(self, timestep, schedule_timesteps=None): - if schedule_timesteps is None: - schedule_timesteps = self.timesteps - - indices = (schedule_timesteps == timestep).nonzero() - - # The sigma index that is taken for the **very** first `step` - # is always the second index (or the last index if there is only 1) - # This way we can ensure we don't accidentally skip a sigma in - # case we start in the middle of the denoising schedule (e.g. for image-to-image) - pos = 1 if len(indices) > 1 else 0 - - return indices[pos].item() - # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.scale_model_input def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: """ @@ -171,11 +162,6 @@ def add_noise(self, noise: torch.FloatTensor, timesteps: Union[torch.FloatTensor, torch.IntTensor, torch.LongTensor]) -> torch.FloatTensor: - if not timesteps.is_floating_point(): - if not self.config.num_train_timesteps: - raise TypeError("Discrete timesteps require `self.config.num_train_timesteps` to be set.") - timesteps = timesteps / self.config.num_train_timesteps # Normalize to [0, 1] - log_snr = self.log_snr(timesteps) log_snr = log_snr.view(timesteps.size(0), *((1,) * (original_samples.ndim - 1))) @@ -187,7 +173,7 @@ def add_noise(self, def step(self, model_output: torch.FloatTensor, - timestep: Union[int, float, torch.FloatTensor, torch.IntTensor, torch.LongTensor], + timestep: Union[int, float, torch.Tensor], sample: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True) -> Union[VDMSchedulerOutput, Tuple]: From fb004577e66b9373c1ac67569a650a11c06aff5b Mon Sep 17 00:00:00 2001 From: Matthias Humt Date: Thu, 18 Apr 2024 17:49:23 +0200 Subject: [PATCH 06/16] Refactor VDMScheduler in scheduling_vdm.py - Add `clip_sample` boolean flag to control whether to clip the predicted original sample within a specified range. - Add `thresholding` boolean flag to enable dynamic thresholding of the predicted original sample. - Add `dynamic_thresholding_ratio` to determine the percentile value for dynamic thresholding. - Add `clip_sample_range` to specify the range for clipping the predicted original sample. - Add `sample_max_value` to set the maximum value for dynamic thresholding. - Implement `_threshold_sample` method to perform dynamic thresholding on the sample. - Modify `scale_model_input` method to use the new `clip_sample` and `thresholding` flags. - Refactor the computation of predicted original and previous samples in the `__call__` method. - Add noise to the predicted previous sample if `noise_scale` is greater than 0. These changes enhance the flexibility and control over the sample generation process in VDMScheduler. --- src/diffusers/schedulers/scheduling_vdm.py | 80 ++++++++++++++++++---- 1 file changed, 65 insertions(+), 15 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_vdm.py b/src/diffusers/schedulers/scheduling_vdm.py index bac10663c8ce..f2b49d792411 100644 --- a/src/diffusers/schedulers/scheduling_vdm.py +++ b/src/diffusers/schedulers/scheduling_vdm.py @@ -69,7 +69,12 @@ class VDMScheduler(SchedulerMixin, ConfigMixin): def __init__(self, num_train_timesteps: Optional[int] = None, beta_schedule: str = "linear", + clip_sample: bool = True, prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + clip_sample_range: float = 1.0, + sample_max_value: float = 1.0, timestep_spacing: str = "leading", steps_offset: Union[int, float] = 0): # Hardcoded as continuous schedules in self._log_snr are fitted to these values @@ -113,7 +118,7 @@ def get_timesteps(self, num_steps: Optional[int] = None) -> np.ndarray: else: raise ValueError(f"`{self.config.timestep_spacing}` timestep spacing is not supported." "Choose one of 'linspace', 'leading' or 'trailing'.") - return timesteps.copy() + return timesteps.astype(np.float32).copy() def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None): if not self.config.num_train_timesteps: @@ -139,6 +144,40 @@ def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, to timesteps += self.config.steps_offset self.timesteps = torch.from_numpy(timesteps).to(device) + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.scale_model_input def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: """ @@ -177,6 +216,7 @@ def step(self, sample: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True) -> Union[VDMSchedulerOutput, Tuple]: + # From https://github.com/addtt/variational-diffusion-models/blob/7f81074dfdfc897178ad3d471458ea03e16197e8/vdm.py#L29 if isinstance(timestep, (int, float)): timestep = torch.tensor(timestep, @@ -189,9 +229,6 @@ def step(self, timestep = timestep / self.config.num_train_timesteps # Normalize to [0, 1] prev_timestep = (timestep - 1 / len(self)).clamp(0, 1) - if prev_timestep > timestep: - raise ValueError("`self.timesteps` must be in descending order.") - # 1. Compute current and previous alpha and sigma values log_snr = self.log_snr(timestep) prev_log_snr = self.log_snr(prev_timestep) @@ -199,24 +236,37 @@ def step(self, alpha, sigma = torch.sigmoid(log_snr), torch.sigmoid(-log_snr) prev_alpha, prev_sigma = torch.sigmoid(prev_log_snr), torch.sigmoid(-prev_log_snr) - # 2. Compute predicted original sample x_0 and predicted previous sample x_{t-1} - c = -torch.expm1(log_snr - prev_log_snr) + # 2. Compute predicted original sample x_0 if self.config.prediction_type == "epsilon": pred_original_sample = (sample - torch.sqrt(sigma) * model_output) / torch.sqrt(alpha) - pred_prev_sample = torch.sqrt(prev_alpha / alpha) * (sample - c * torch.sqrt(sigma) * model_output) elif self.config.prediction_type == "sample": pred_original_sample = model_output - pred_prev_sample = torch.sqrt(prev_alpha) * (sample * (1 - c) / torch.sqrt(alpha) + c * model_output) else: raise ValueError("`prediction_type` must be either `epsilon` or `sample`.") - # 3. Add noise - noise = randn_tensor(model_output.shape, - generator=generator, - device=model_output.device, - dtype=model_output.dtype) - variance = torch.sqrt(prev_sigma * c) * noise - pred_prev_sample = pred_prev_sample + variance + # 3. Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp(-self.config.clip_sample_range, + self.config.clip_sample_range) + + # 4. Computed predicted previous sample x_{t-1} + c = -torch.expm1(log_snr - prev_log_snr) + if self.config.thresholding or self.config.clip_sample: + pred_prev_sample = torch.sqrt(prev_alpha) * (sample * (1 - c) / torch.sqrt(alpha) + c * pred_original_sample) + else: + pred_prev_sample = torch.sqrt(prev_alpha / alpha) * (sample - c * torch.sqrt(sigma) * model_output) + + # 5. (Maybe) add noise + noise_scale = torch.sqrt(prev_sigma * c) # Becomes 0 for prev_timestep = 0 + if noise_scale > 0: + noise = randn_tensor(model_output.shape, + generator=generator, + device=model_output.device, + dtype=model_output.dtype) + variance = noise_scale * noise + pred_prev_sample = pred_prev_sample + variance if not return_dict: return (pred_prev_sample,) From 487f87115b9976072537a2adbc0ff0e52cdc88bb Mon Sep 17 00:00:00 2001 From: Matthias Humt Date: Fri, 19 Apr 2024 11:44:23 +0200 Subject: [PATCH 07/16] Refactor VDMScheduler class in scheduling_vdm.py - Refactored the add_noise() method to use torch.Tensor instead of specific tensor types. - Updated the step() method to handle batched inputs. - Modified the computation of predicted original sample x_0 in the step() method. - Adjusted the addition of noise in the step() method to handle noise_scale > 0. These changes improve the flexibility and efficiency of the VDMScheduler class in the scheduling_vdm.py file. --- src/diffusers/schedulers/scheduling_vdm.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_vdm.py b/src/diffusers/schedulers/scheduling_vdm.py index f2b49d792411..9425255713b2 100644 --- a/src/diffusers/schedulers/scheduling_vdm.py +++ b/src/diffusers/schedulers/scheduling_vdm.py @@ -197,9 +197,9 @@ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = return sample def add_noise(self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: Union[torch.FloatTensor, torch.IntTensor, torch.LongTensor]) -> torch.FloatTensor: + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor) -> torch.FloatTensor: log_snr = self.log_snr(timesteps) log_snr = log_snr.view(timesteps.size(0), *((1,) * (original_samples.ndim - 1))) @@ -211,9 +211,9 @@ def add_noise(self, return noisy_samples def step(self, - model_output: torch.FloatTensor, + model_output: torch.Tensor, timestep: Union[int, float, torch.Tensor], - sample: torch.FloatTensor, + sample: torch.Tensor, generator: Optional[torch.Generator] = None, return_dict: bool = True) -> Union[VDMSchedulerOutput, Tuple]: # From https://github.com/addtt/variational-diffusion-models/blob/7f81074dfdfc897178ad3d471458ea03e16197e8/vdm.py#L29 @@ -233,12 +233,17 @@ def step(self, log_snr = self.log_snr(timestep) prev_log_snr = self.log_snr(prev_timestep) + # Allow for batched inputs + if timestep.ndim > 0: + log_snr = log_snr.view(timestep.size(0), *((1,) * (sample.ndim - 1))) + prev_log_snr = prev_log_snr.view(timestep.size(0), *((1,) * (sample.ndim - 1))) + alpha, sigma = torch.sigmoid(log_snr), torch.sigmoid(-log_snr) prev_alpha, prev_sigma = torch.sigmoid(prev_log_snr), torch.sigmoid(-prev_log_snr) # 2. Compute predicted original sample x_0 if self.config.prediction_type == "epsilon": - pred_original_sample = (sample - torch.sqrt(sigma) * model_output) / torch.sqrt(alpha) + pred_original_sample = (sample - torch.sqrt(sigma) * model_output) / torch.sqrt(alpha) # Sec. 3.4, eq. 10 elif self.config.prediction_type == "sample": pred_original_sample = model_output else: @@ -260,13 +265,12 @@ def step(self, # 5. (Maybe) add noise noise_scale = torch.sqrt(prev_sigma * c) # Becomes 0 for prev_timestep = 0 - if noise_scale > 0: + if torch.any(noise_scale > 0): noise = randn_tensor(model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype) - variance = noise_scale * noise - pred_prev_sample = pred_prev_sample + variance + pred_prev_sample += noise_scale * noise if not return_dict: return (pred_prev_sample,) From 390241117def50a91295b323d16502758bb32c38 Mon Sep 17 00:00:00 2001 From: Matthias Humt Date: Fri, 19 Apr 2024 16:39:17 +0200 Subject: [PATCH 08/16] Refactor VDMScheduler's predicted previous sample computation - Refactored the computation of the predicted previous sample in the VDMScheduler class. - Added a condition to include the prediction type "sample" in the computation. - The computation now considers thresholding, clipping, and the prediction type "sample" to calculate the predicted previous sample. - This change improves the accuracy of the predicted previous sample in certain scenarios. --- src/diffusers/schedulers/scheduling_vdm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_vdm.py b/src/diffusers/schedulers/scheduling_vdm.py index 9425255713b2..679a191f5237 100644 --- a/src/diffusers/schedulers/scheduling_vdm.py +++ b/src/diffusers/schedulers/scheduling_vdm.py @@ -258,7 +258,7 @@ def step(self, # 4. Computed predicted previous sample x_{t-1} c = -torch.expm1(log_snr - prev_log_snr) - if self.config.thresholding or self.config.clip_sample: + if self.config.thresholding or self.config.clip_sample or self.config.prediction_type == "sample": pred_prev_sample = torch.sqrt(prev_alpha) * (sample * (1 - c) / torch.sqrt(alpha) + c * pred_original_sample) else: pred_prev_sample = torch.sqrt(prev_alpha / alpha) * (sample - c * torch.sqrt(sigma) * model_output) From 6199b57119cbaf410875c2584c74207c50c776b4 Mon Sep 17 00:00:00 2001 From: Matthias Humt Date: Mon, 22 Apr 2024 08:57:48 +0200 Subject: [PATCH 09/16] Refactor VDMScheduler class and add docstrings This commit refactors the VDMScheduler class in the scheduling_vdm.py file. The class now includes docstrings that provide detailed explanations of the class and its methods. The log_snr method calculates the logarithm of the signal-to-noise ratio for given timesteps. The get_timesteps method generates an array of timesteps based on the configured spacing method. The set_timesteps method sets the discrete timesteps used for the diffusion chain. The add_noise method adds noise to the original samples according to the noise schedule and specified timesteps. The step method performs a single step of the diffusion process, computing the previous sample and optionally the predicted original sample. These changes improve code readability and maintainability. --- src/diffusers/schedulers/scheduling_vdm.py | 129 ++++++++++++++++++++- 1 file changed, 125 insertions(+), 4 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_vdm.py b/src/diffusers/schedulers/scheduling_vdm.py index 679a191f5237..b3c40561f02c 100644 --- a/src/diffusers/schedulers/scheduling_vdm.py +++ b/src/diffusers/schedulers/scheduling_vdm.py @@ -29,10 +29,26 @@ def log_snr(t: torch.FloatTensor, beta_schedule: str) -> torch.FloatTensor: + """ + Calculates the logarithm of the signal-to-noise ratio (SNR) for given time steps `t` under a specified beta schedule. + + The function supports multiple beta schedules, which are key to controlling the noise levels in diffusion models. + It returns the logarithmic SNR, which is used to compute model parameters at different diffusion steps. + + Args: + t (torch.FloatTensor): Tensor of time steps, normalized between [0, 1]. + beta_schedule (str): The beta schedule type. Supported types include 'linear', 'squaredcos_cap_v2', and 'sigmoid'. + + Returns: + torch.FloatTensor: The log SNR values corresponding to the input time steps under the given beta schedule. + + Raises: + ValueError: If `t` is outside the range [0, 1] or if the beta_schedule is unsupported. + """ if t.min() < 0 or t.max() > 1: raise ValueError("`t` must be in range [0, 1].") - # From https://github.com/Zhengxinyang/LAS-Diffusion/blob/a7eb304a24dec2eb85a8d3899c73338e10435bba/network/model_utils.py#L345 + # From https://github.com/Zhengxinyang/LAS-Diffusion/blob/main/network/model_utils.py#L345 if beta_schedule == "linear": return -torch.log(torch.special.expm1(1e-4 + 10 * t ** 2)) elif beta_schedule == "squaredcos_cap_v2": @@ -65,15 +81,52 @@ class VDMSchedulerOutput(BaseOutput): class VDMScheduler(SchedulerMixin, ConfigMixin): + """ + Implements the discrete and continuous scheduler as presented in `Variational Diffusion Models` [1]. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to None, *optional*): + The number of diffusion steps to train the model. If not provided, assumes continuous formulation. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + clip_sample (`bool`, defaults to `True`): + Clip the predicted sample for numerical stability. + clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, defaults to `"leading"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + + References: + [1] "Variational Diffusion Models" by Diederik P. Kingma, Tim Salimans, Ben Poole and Jonathan Ho, ArXiv, 2021. + """ + @register_to_config def __init__(self, num_train_timesteps: Optional[int] = None, beta_schedule: str = "linear", clip_sample: bool = True, + clip_sample_range: float = 1.0, prediction_type: str = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, - clip_sample_range: float = 1.0, sample_max_value: float = 1.0, timestep_spacing: str = "leading", steps_offset: Union[int, float] = 0): @@ -97,9 +150,22 @@ def __init__(self, self.betas = 1 - self.alphas def __len__(self) -> int: + """Returns the number of inference steps or the number of training timesteps or 1000, whichever is set.""" return self.num_inference_steps or self.config.num_train_timesteps or 1000 def log_snr(self, timesteps: torch.Tensor) -> torch.FloatTensor: + """ + Computes the logarithm of the signal-to-noise ratio for given timesteps using the configured beta schedule. + + Args: + timesteps (torch.Tensor): Tensor of timesteps, which can be either normalized to [0, 1] range or discrete. + + Returns: + torch.FloatTensor: The computed log SNR values for the given timesteps. + + Raises: + TypeError: If discrete timesteps are used without setting `num_train_timesteps` in the configuration. + """ if not timesteps.is_floating_point(): if not self.config.num_train_timesteps: raise TypeError("Discrete timesteps require `self.config.num_train_timesteps` to be set.") @@ -107,20 +173,47 @@ def log_snr(self, timesteps: torch.Tensor) -> torch.FloatTensor: return log_snr(timesteps, beta_schedule=self.config.beta_schedule) + def get_timesteps(self, num_steps: Optional[int] = None) -> np.ndarray: + """ + Generates an array of timesteps based on the configured spacing method, either evenly spaced or leading/trailing. + + Args: + num_steps (int, optional): The number of timesteps to generate. Defaults to `num_train_timesteps`. + + Returns: + np.ndarray: An array of timesteps, distributed according to the `timestep_spacing` configuration. + + Raises: + ValueError: If an unsupported `timestep_spacing` configuration is provided. + """ if num_steps is None: num_steps = self.config.num_train_timesteps if self.config.timestep_spacing in ["linspace", "leading"]: timesteps = np.linspace(0, 1, num_steps, endpoint=self.config.timestep_spacing == "linspace")[::-1] - elif self.config.time_spacing == "trailing": + elif self.config.timestep_spacing == "trailing": timesteps = np.arange(1, 0, -1 / num_steps) - 1 / num_steps else: raise ValueError(f"`{self.config.timestep_spacing}` timestep spacing is not supported." "Choose one of 'linspace', 'leading' or 'trailing'.") return timesteps.astype(np.float32).copy() + def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + + Raises: + ValueError: If an unsupported `timestep_spacing` configuration is provided. + """ if not self.config.num_train_timesteps: timesteps = self.get_timesteps(num_inference_steps) else: @@ -200,7 +293,20 @@ def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.FloatTensor: + """ + Adds noise to the original samples according to the noise schedule and the specified timesteps. + This method calculates the noisy samples by combining the original samples with Gaussian noise + scaled according to the time-dependent noise levels dictated by the signal-to-noise ratio. + + Args: + original_samples (torch.Tensor): The original samples from the data distribution before noise is added. + noise (torch.Tensor): Gaussian noise to be added to the samples. + timesteps (torch.Tensor): Timesteps at which the samples are processed. + + Returns: + torch.FloatTensor: The noisy samples after adding scaled Gaussian noise according to the SNR. + """ log_snr = self.log_snr(timesteps) log_snr = log_snr.view(timesteps.size(0), *((1,) * (original_samples.ndim - 1))) @@ -216,7 +322,22 @@ def step(self, sample: torch.Tensor, generator: Optional[torch.Generator] = None, return_dict: bool = True) -> Union[VDMSchedulerOutput, Tuple]: - # From https://github.com/addtt/variational-diffusion-models/blob/7f81074dfdfc897178ad3d471458ea03e16197e8/vdm.py#L29 + """ + Performs a single step of the diffusion process, computing the previous sample and optionally the predicted + original sample based on the model output and current timestep. + + Args: + model_output (torch.Tensor): The output from the diffusion model, typically noise predictions. + timestep (int, float, torch.Tensor): Current timestep in the diffusion process. + sample (torch.Tensor): The current sample at timestep `t`. + generator (torch.Generator, *optional*): Generator for random numbers, used for adding noise. + return_dict (bool): If True, returns a `VDMSchedulerOutput` object; otherwise, returns a tuple. + + Returns: + VDMSchedulerOutput or Tuple: Depending on `return_dict`, returns either a data class containing + the previous sample and predicted original sample, or just the previous sample as a tuple. + """ + # Based on https://github.com/addtt/variational-diffusion-models/blob/main/vdm.py#L29 if isinstance(timestep, (int, float)): timestep = torch.tensor(timestep, From 7c4a4a725e995f2c5a8fb28bdddb85e062e5d471 Mon Sep 17 00:00:00 2001 From: Matthias Humt Date: Mon, 22 Apr 2024 09:06:39 +0200 Subject: [PATCH 10/16] Update copyright information in scheduling_vdm.py --- src/diffusers/schedulers/scheduling_vdm.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_vdm.py b/src/diffusers/schedulers/scheduling_vdm.py index b3c40561f02c..6a1d0c0ee64c 100644 --- a/src/diffusers/schedulers/scheduling_vdm.py +++ b/src/diffusers/schedulers/scheduling_vdm.py @@ -1,4 +1,4 @@ -# Copyright 2024 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# Copyright 2024 Katherine Crowson and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim - import math from dataclasses import dataclass from typing import List, Optional, Tuple, Union @@ -32,8 +30,7 @@ def log_snr(t: torch.FloatTensor, beta_schedule: str) -> torch.FloatTensor: """ Calculates the logarithm of the signal-to-noise ratio (SNR) for given time steps `t` under a specified beta schedule. - The function supports multiple beta schedules, which are key to controlling the noise levels in diffusion models. - It returns the logarithmic SNR, which is used to compute model parameters at different diffusion steps. + See appendix K of the [Variational Diffusion Models](https://arxiv.org/abs/2107.00630) paper for more details. Args: t (torch.FloatTensor): Tensor of time steps, normalized between [0, 1]. @@ -130,12 +127,12 @@ def __init__(self, sample_max_value: float = 1.0, timestep_spacing: str = "leading", steps_offset: Union[int, float] = 0): - # Hardcoded as continuous schedules in self._log_snr are fitted to these values + # Hardcoded as continuous schedules in `log_snr` are fitted to these values self.beta_start = 1e-4 self.beta_end = 0.02 self.init_noise_sigma = 1.0 - # For linear beta schedule equivalent to torch.exp(-1e-4 - 10 * t ** 2) + # For linear beta schedule, equivalent to torch.exp(-1e-4 - 10 * t ** 2) self.alphas_cumprod = lambda t: torch.sigmoid(self.log_snr(t)) # Equivalent to 1 - self.sigmas self.sigmas = lambda t: torch.sigmoid(-self.log_snr(t)) # Equivalent to 1 - self.alphas_cumprod @@ -176,7 +173,7 @@ def log_snr(self, timesteps: torch.Tensor) -> torch.FloatTensor: def get_timesteps(self, num_steps: Optional[int] = None) -> np.ndarray: """ - Generates an array of timesteps based on the configured spacing method, either evenly spaced or leading/trailing. + Generates timesteps in the range [0, 1] for the continuous formulation. Args: num_steps (int, optional): The number of timesteps to generate. Defaults to `num_train_timesteps`. @@ -202,7 +199,7 @@ def get_timesteps(self, num_steps: Optional[int] = None) -> np.ndarray: def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None): """ - Sets the discrete timesteps used for the diffusion chain (to be run before inference). + Sets the discrete or continuous timesteps used for the diffusion chain (to be run before inference). Args: num_inference_steps (`int`): @@ -308,6 +305,7 @@ def add_noise(self, torch.FloatTensor: The noisy samples after adding scaled Gaussian noise according to the SNR. """ log_snr = self.log_snr(timesteps) + # Reshape from (1,) to (B, ...) where B is the batch size and ... are the spatial dimensions log_snr = log_snr.view(timesteps.size(0), *((1,) * (original_samples.ndim - 1))) sqrt_alpha_prod = torch.sqrt(torch.sigmoid(log_snr)) From 533ae9bf99113166acb85282d82630b8e299a946 Mon Sep 17 00:00:00 2001 From: Matthias Humt Date: Mon, 22 Apr 2024 11:13:58 +0200 Subject: [PATCH 11/16] Add VDMScheduler to diffusers/__init__.py and implement VDMScheduler in diffusers/schedulers/scheduling_vdm.py - Added "VDMScheduler" to the list of imported schedulers in diffusers/__init__.py. - Implemented the VDMScheduler class in diffusers/schedulers/scheduling_vdm.py. - Added the log_snr() function to calculate the logarithm of the signal-to-noise ratio (SNR) for given time steps and beta schedule. - Created the VDMSchedulerOutput class as the output for the scheduler's step function. - Implemented the VDMScheduler class with necessary methods and attributes. - Added test cases for VDMScheduler in tests/schedulers/test_scheduler_vdm.py. - Updated the test cases in tests/schedulers/test_schedulers.py to include VDMScheduler. --- src/diffusers/__init__.py | 2 + src/diffusers/schedulers/scheduling_vdm.py | 53 +++++----- tests/schedulers/test_scheduler_vdm.py | 109 +++++++++++++++++++++ tests/schedulers/test_schedulers.py | 3 +- 4 files changed, 137 insertions(+), 30 deletions(-) create mode 100644 tests/schedulers/test_scheduler_vdm.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 5d6761663938..b7b624e6744c 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -167,6 +167,7 @@ "UnCLIPScheduler", "UniPCMultistepScheduler", "VQDiffusionScheduler", + "VDMScheduler", ] ) _import_structure["training_utils"] = ["EMAModel"] @@ -562,6 +563,7 @@ UnCLIPScheduler, UniPCMultistepScheduler, VQDiffusionScheduler, + VDMScheduler, ) from .training_utils import EMAModel diff --git a/src/diffusers/schedulers/scheduling_vdm.py b/src/diffusers/schedulers/scheduling_vdm.py index 6a1d0c0ee64c..d1bef735dc69 100644 --- a/src/diffusers/schedulers/scheduling_vdm.py +++ b/src/diffusers/schedulers/scheduling_vdm.py @@ -26,18 +26,18 @@ from .scheduling_utils import SchedulerMixin -def log_snr(t: torch.FloatTensor, beta_schedule: str) -> torch.FloatTensor: +def log_snr(t: torch.Tensor, beta_schedule: str) -> torch.Tensor: """ Calculates the logarithm of the signal-to-noise ratio (SNR) for given time steps `t` under a specified beta schedule. See appendix K of the [Variational Diffusion Models](https://arxiv.org/abs/2107.00630) paper for more details. Args: - t (torch.FloatTensor): Tensor of time steps, normalized between [0, 1]. + t (torch.Tensor): Tensor of time steps, normalized between [0, 1]. beta_schedule (str): The beta schedule type. Supported types include 'linear', 'squaredcos_cap_v2', and 'sigmoid'. Returns: - torch.FloatTensor: The log SNR values corresponding to the input time steps under the given beta schedule. + torch.Tensor: The log SNR values corresponding to the input time steps under the given beta schedule. Raises: ValueError: If `t` is outside the range [0, 1] or if the beta_schedule is unsupported. @@ -65,16 +65,16 @@ class VDMSchedulerOutput(BaseOutput): Output class for the scheduler's `step` function output. Args: - prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the denoising loop. - pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): The predicted denoised sample `(x_{0})` based on the model output from the current timestep. `pred_original_sample` can be used to preview progress or for guidance. """ - prev_sample: torch.FloatTensor - pred_original_sample: Optional[torch.FloatTensor] = None + prev_sample: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None class VDMScheduler(SchedulerMixin, ConfigMixin): @@ -89,15 +89,14 @@ class VDMScheduler(SchedulerMixin, ConfigMixin): The number of diffusion steps to train the model. If not provided, assumes continuous formulation. beta_schedule (`str`, defaults to `"linear"`): The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + `linear`, `squaredcos_cap_v2` or `sigmoid`. clip_sample (`bool`, defaults to `True`): Clip the predicted sample for numerical stability. clip_sample_range (`float`, defaults to 1.0): The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. prediction_type (`str`, defaults to `epsilon`, *optional*): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), - `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen - Video](https://imagen.research.google/video/paper.pdf) paper). + or `sample` (directly predicts the noisy sample`). thresholding (`bool`, defaults to `False`): Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such as Stable Diffusion. @@ -137,12 +136,10 @@ def __init__(self, self.sigmas = lambda t: torch.sigmoid(-self.log_snr(t)) # Equivalent to 1 - self.alphas_cumprod self.num_inference_steps = None - self.timesteps = None + self.timesteps = torch.from_numpy(self.get_timesteps(len(self))) if num_train_timesteps: - # TODO: Might not be exact - self.timesteps = torch.from_numpy(self.get_timesteps(len(self))) alphas_cumprod = self.alphas_cumprod(torch.flip(self.timesteps, dims=(0,))) - alphas = alphas_cumprod[1:] / alphas_cumprod[:-1] + alphas = alphas_cumprod[1:] / alphas_cumprod[:-1] # TODO: Might not be exact self.alphas = torch.cat([alphas_cumprod[:1], alphas]) self.betas = 1 - self.alphas @@ -150,7 +147,7 @@ def __len__(self) -> int: """Returns the number of inference steps or the number of training timesteps or 1000, whichever is set.""" return self.num_inference_steps or self.config.num_train_timesteps or 1000 - def log_snr(self, timesteps: torch.Tensor) -> torch.FloatTensor: + def log_snr(self, timesteps: torch.Tensor) -> torch.Tensor: """ Computes the logarithm of the signal-to-noise ratio for given timesteps using the configured beta schedule. @@ -158,7 +155,7 @@ def log_snr(self, timesteps: torch.Tensor) -> torch.FloatTensor: timesteps (torch.Tensor): Tensor of timesteps, which can be either normalized to [0, 1] range or discrete. Returns: - torch.FloatTensor: The computed log SNR values for the given timesteps. + torch.Tensor: The computed log SNR values for the given timesteps. Raises: TypeError: If discrete timesteps are used without setting `num_train_timesteps` in the configuration. @@ -170,7 +167,6 @@ def log_snr(self, timesteps: torch.Tensor) -> torch.FloatTensor: return log_snr(timesteps, beta_schedule=self.config.beta_schedule) - def get_timesteps(self, num_steps: Optional[int] = None) -> np.ndarray: """ Generates timesteps in the range [0, 1] for the continuous formulation. @@ -185,7 +181,7 @@ def get_timesteps(self, num_steps: Optional[int] = None) -> np.ndarray: ValueError: If an unsupported `timestep_spacing` configuration is provided. """ if num_steps is None: - num_steps = self.config.num_train_timesteps + num_steps = len(self) if self.config.timestep_spacing in ["linspace", "leading"]: timesteps = np.linspace(0, 1, num_steps, endpoint=self.config.timestep_spacing == "linspace")[::-1] @@ -196,7 +192,6 @@ def get_timesteps(self, num_steps: Optional[int] = None) -> np.ndarray: "Choose one of 'linspace', 'leading' or 'trailing'.") return timesteps.astype(np.float32).copy() - def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None): """ Sets the discrete or continuous timesteps used for the diffusion chain (to be run before inference). @@ -235,7 +230,7 @@ def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, to self.timesteps = torch.from_numpy(timesteps).to(device) # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample - def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: """ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by @@ -269,19 +264,19 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: return sample # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.scale_model_input - def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. Args: - sample (`torch.FloatTensor`): + sample (`torch.Tensor`): The input sample. timestep (`int`, *optional*): The current timestep in the diffusion chain. Returns: - `torch.FloatTensor`: + `torch.Tensor`: A scaled input sample. """ return sample @@ -289,7 +284,7 @@ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, - timesteps: torch.Tensor) -> torch.FloatTensor: + timesteps: torch.Tensor) -> torch.Tensor: """ Adds noise to the original samples according to the noise schedule and the specified timesteps. @@ -302,14 +297,14 @@ def add_noise(self, timesteps (torch.Tensor): Timesteps at which the samples are processed. Returns: - torch.FloatTensor: The noisy samples after adding scaled Gaussian noise according to the SNR. + torch.Tensor: The noisy samples after adding scaled Gaussian noise according to the SNR. """ - log_snr = self.log_snr(timesteps) + gamma = self.log_snr(timesteps).to(original_samples.device) # Reshape from (1,) to (B, ...) where B is the batch size and ... are the spatial dimensions - log_snr = log_snr.view(timesteps.size(0), *((1,) * (original_samples.ndim - 1))) + gamma = gamma.view(timesteps.size(0), *((1,) * (original_samples.ndim - 1))) - sqrt_alpha_prod = torch.sqrt(torch.sigmoid(log_snr)) - sqrt_one_minus_alpha_prod = torch.sqrt(torch.sigmoid(-log_snr)) # sqrt(sigma) + sqrt_alpha_prod = torch.sqrt(torch.sigmoid(gamma)) + sqrt_one_minus_alpha_prod = torch.sqrt(torch.sigmoid(-gamma)) # sqrt(sigma) noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples diff --git a/tests/schedulers/test_scheduler_vdm.py b/tests/schedulers/test_scheduler_vdm.py new file mode 100644 index 000000000000..4d06c30d7ee5 --- /dev/null +++ b/tests/schedulers/test_scheduler_vdm.py @@ -0,0 +1,109 @@ +import torch + +from diffusers import VDMScheduler + +from .test_schedulers import SchedulerCommonTest + + +class VDMSchedulerTest(SchedulerCommonTest): + scheduler_classes = (VDMScheduler,) + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "beta_schedule": "linear", + "clip_sample": True, + } + + config.update(**kwargs) + return config + + def test_timesteps(self): + for timesteps in [None, 1, 5, 100, 1000]: + self.check_over_configs(time_step=0.0 if timesteps is None else 0, num_train_timesteps=timesteps) + + def test_schedules(self): + for schedule in ["linear", "squaredcos_cap_v2", "sigmoid"]: + self.check_over_configs(beta_schedule=schedule) + + def test_clip_sample(self): + for clip_sample in [True, False]: + self.check_over_configs(clip_sample=clip_sample) + + def test_thresholding(self): + self.check_over_configs(thresholding=False) + for threshold in [0.5, 1.0, 2.0]: + for prediction_type in ["epsilon", "sample"]: + self.check_over_configs( + thresholding=True, + prediction_type=prediction_type, + sample_max_value=threshold, + ) + + def test_prediction_type(self): + for prediction_type in ["epsilon", "sample"]: + self.check_over_configs(prediction_type=prediction_type) + + def test_time_indices(self): + for t in [0, 500, 999]: + self.check_over_forward(time_step=t) # Discrete + self.check_over_forward(time_step=t / 1000) # Continuous + + def test_full_loop_no_noise(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + num_trained_timesteps = len(scheduler) + + model = self.dummy_model() + sample = self.dummy_sample_deter + print(sample.abs().sum()) + generator = torch.manual_seed(0) + + for t in reversed(range(num_trained_timesteps)): + # 1. predict noise residual + residual = model(sample, t) + if t == len(scheduler) - 1: + print(residual.abs().sum()) + + # 2. predict previous mean of sample x_t-1 + sample = scheduler.step(residual, t, sample, generator=generator).prev_sample + if t == len(scheduler) - 1: + print(sample.abs().sum()) + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 256.4699) < 1e-2, f" expected result sum 256.4699, but get {result_sum}" + assert abs(result_mean.item() - 0.3339) < 1e-3, f" expected result mean 0.3339, but get {result_mean}" + + def test_full_loop_with_noise(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + num_trained_timesteps = len(scheduler) + t_start = num_trained_timesteps - 2 + + model = self.dummy_model() + sample = self.dummy_sample_deter + generator = torch.manual_seed(0) + + # add noise + noise = self.dummy_noise_deter + timesteps = scheduler.timesteps[t_start:] + sample = scheduler.add_noise(sample, noise, timesteps[:1]) + + for t in timesteps: + # 1. predict noise residual + residual = model(sample, t) + + # 2. predict previous mean of sample x_t-1 + sample = scheduler.step(residual, t, sample, generator=generator).prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 387.6469) < 1e-2, f" expected result sum 387.6469, but get {result_sum}" + assert abs(result_mean.item() - 0.5051) < 1e-3, f" expected result mean 0.5051, but get {result_mean}" diff --git a/tests/schedulers/test_schedulers.py b/tests/schedulers/test_schedulers.py index fc7f22d2a8e5..67621d4214f2 100755 --- a/tests/schedulers/test_schedulers.py +++ b/tests/schedulers/test_schedulers.py @@ -37,6 +37,7 @@ LMSDiscreteScheduler, UniPCMultistepScheduler, VQDiffusionScheduler, + VDMScheduler, ) from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.schedulers.scheduling_utils import SchedulerMixin @@ -750,7 +751,7 @@ def test_deprecated_kwargs(self): def test_trained_betas(self): for scheduler_class in self.scheduler_classes: - if scheduler_class in (VQDiffusionScheduler, CMStochasticIterativeScheduler): + if scheduler_class in (VQDiffusionScheduler, CMStochasticIterativeScheduler, VDMScheduler): continue scheduler_config = self.get_scheduler_config() From a05af00329fac89fd709ac21712213200616cb44 Mon Sep 17 00:00:00 2001 From: Matthias Humt Date: Mon, 22 Apr 2024 11:16:17 +0200 Subject: [PATCH 12/16] Refactor diffusers module and schedulers - Remove "VQDiffusionScheduler" from the import structure in "__init__.py" of the "diffusers" module. - Remove "VQDiffusionScheduler" from the import structure in "__init__.py" of the "schedulers" sub-module. - Remove "VQDiffusionScheduler" import from "scheduling_vq_diffusion.py" in the "schedulers" sub-module. - Update the "VDMScheduler" class in "scheduling_vdm.py" in the "schedulers" sub-module: - Add type hints and remove unused imports. - Refactor the "log_snr" function for better readability. - Refactor the "__init__" method for better readability. - Refactor the "set_timesteps" method for better readability. - Refactor the "add_noise" method for better readability. - Refactor the "step" method for better readability. - Update the test cases in "test_schedulers.py" to reflect the changes. These changes refactor the import structure and improve the readability of the code in the "diffusers" module and the "schedulers" sub-module. --- src/diffusers/__init__.py | 4 +- src/diffusers/schedulers/__init__.py | 4 +- src/diffusers/schedulers/scheduling_vdm.py | 118 ++++++++++++--------- tests/schedulers/test_schedulers.py | 2 +- 4 files changed, 70 insertions(+), 58 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index b7b624e6744c..48f8d59d98a1 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -166,8 +166,8 @@ "TCDScheduler", "UnCLIPScheduler", "UniPCMultistepScheduler", - "VQDiffusionScheduler", "VDMScheduler", + "VQDiffusionScheduler", ] ) _import_structure["training_utils"] = ["EMAModel"] @@ -562,8 +562,8 @@ TCDScheduler, UnCLIPScheduler, UniPCMultistepScheduler, - VQDiffusionScheduler, VDMScheduler, + VQDiffusionScheduler, ) from .training_utils import EMAModel diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 2cb401b86538..aebce2bc531a 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -69,8 +69,8 @@ _import_structure["scheduling_unclip"] = ["UnCLIPScheduler"] _import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"] _import_structure["scheduling_utils"] = ["KarrasDiffusionSchedulers", "SchedulerMixin"] - _import_structure["scheduling_vq_diffusion"] = ["VQDiffusionScheduler"] _import_structure["scheduling_vdm"] = ["VDMScheduler"] + _import_structure["scheduling_vq_diffusion"] = ["VQDiffusionScheduler"] try: if not is_flax_available(): @@ -165,8 +165,8 @@ from .scheduling_unclip import UnCLIPScheduler from .scheduling_unipc_multistep import UniPCMultistepScheduler from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin - from .scheduling_vq_diffusion import VQDiffusionScheduler from .scheduling_vdm import VDMScheduler + from .scheduling_vq_diffusion import VQDiffusionScheduler try: if not is_flax_available(): diff --git a/src/diffusers/schedulers/scheduling_vdm.py b/src/diffusers/schedulers/scheduling_vdm.py index d1bef735dc69..52bebadc51f3 100644 --- a/src/diffusers/schedulers/scheduling_vdm.py +++ b/src/diffusers/schedulers/scheduling_vdm.py @@ -14,8 +14,7 @@ import math from dataclasses import dataclass -from typing import List, Optional, Tuple, Union -from functools import partial +from typing import Optional, Tuple, Union import numpy as np import torch @@ -28,13 +27,15 @@ def log_snr(t: torch.Tensor, beta_schedule: str) -> torch.Tensor: """ - Calculates the logarithm of the signal-to-noise ratio (SNR) for given time steps `t` under a specified beta schedule. + Calculates the logarithm of the signal-to-noise ratio (SNR) for given time steps `t` under a specified beta + schedule. See appendix K of the [Variational Diffusion Models](https://arxiv.org/abs/2107.00630) paper for more details. Args: t (torch.Tensor): Tensor of time steps, normalized between [0, 1]. - beta_schedule (str): The beta schedule type. Supported types include 'linear', 'squaredcos_cap_v2', and 'sigmoid'. + beta_schedule (str): + The beta schedule type. Supported types include 'linear', 'squaredcos_cap_v2', and 'sigmoid'. Returns: torch.Tensor: The log SNR values corresponding to the input time steps under the given beta schedule. @@ -47,7 +48,7 @@ def log_snr(t: torch.Tensor, beta_schedule: str) -> torch.Tensor: # From https://github.com/Zhengxinyang/LAS-Diffusion/blob/main/network/model_utils.py#L345 if beta_schedule == "linear": - return -torch.log(torch.special.expm1(1e-4 + 10 * t ** 2)) + return -torch.log(torch.special.expm1(1e-4 + 10 * t**2)) elif beta_schedule == "squaredcos_cap_v2": return -torch.log(torch.clamp((torch.cos((t + 0.008) / (1 + 0.008) * math.pi * 0.5) ** -2) - 1, min=1e-5)) elif beta_schedule == "sigmoid": @@ -115,17 +116,19 @@ class VDMScheduler(SchedulerMixin, ConfigMixin): """ @register_to_config - def __init__(self, - num_train_timesteps: Optional[int] = None, - beta_schedule: str = "linear", - clip_sample: bool = True, - clip_sample_range: float = 1.0, - prediction_type: str = "epsilon", - thresholding: bool = False, - dynamic_thresholding_ratio: float = 0.995, - sample_max_value: float = 1.0, - timestep_spacing: str = "leading", - steps_offset: Union[int, float] = 0): + def __init__( + self, + num_train_timesteps: Optional[int] = None, + beta_schedule: str = "linear", + clip_sample: bool = True, + clip_sample_range: float = 1.0, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + steps_offset: Union[int, float] = 0, + ): # Hardcoded as continuous schedules in `log_snr` are fitted to these values self.beta_start = 1e-4 self.beta_end = 0.02 @@ -183,13 +186,14 @@ def get_timesteps(self, num_steps: Optional[int] = None) -> np.ndarray: if num_steps is None: num_steps = len(self) if self.config.timestep_spacing in ["linspace", "leading"]: - timesteps = np.linspace(0, 1, num_steps, - endpoint=self.config.timestep_spacing == "linspace")[::-1] + timesteps = np.linspace(0, 1, num_steps, endpoint=self.config.timestep_spacing == "linspace")[::-1] elif self.config.timestep_spacing == "trailing": timesteps = np.arange(1, 0, -1 / num_steps) - 1 / num_steps else: - raise ValueError(f"`{self.config.timestep_spacing}` timestep spacing is not supported." - "Choose one of 'linspace', 'leading' or 'trailing'.") + raise ValueError( + f"`{self.config.timestep_spacing}` timestep spacing is not supported." + "Choose one of 'linspace', 'leading' or 'trailing'." + ) return timesteps.astype(np.float32).copy() def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None): @@ -212,17 +216,24 @@ def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, to if self.config.timestep_spacing in ["linspace", "leading"]: start = 0 stop = self.config.num_train_timesteps - timesteps = np.linspace(start, - stop - 1 if self.config.timestep_spacing == "linspace" else stop, - num_inference_steps, - endpoint=self.config.timestep_spacing == "linspace")[::-1] + timesteps = np.linspace( + start, + stop - 1 if self.config.timestep_spacing == "linspace" else stop, + num_inference_steps, + endpoint=self.config.timestep_spacing == "linspace", + )[::-1] elif self.config.timestep_spacing == "trailing": - timesteps = np.arange(self.config.num_train_timesteps, - 0, - -self.config.num_train_timesteps / num_inference_steps) - 1 + timesteps = ( + np.arange( + self.config.num_train_timesteps, 0, -self.config.num_train_timesteps / num_inference_steps + ) + - 1 + ) else: - raise ValueError(f"`{self.config.timestep_spacing}` timestep spacing is not supported." - "Choose one of 'linspace', 'leading' or 'trailing'.") + raise ValueError( + f"`{self.config.timestep_spacing}` timestep spacing is not supported." + "Choose one of 'linspace', 'leading' or 'trailing'." + ) timesteps = timesteps.round().astype(np.int64).copy() self.num_inference_steps = num_inference_steps @@ -281,15 +292,12 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None """ return sample - def add_noise(self, - original_samples: torch.Tensor, - noise: torch.Tensor, - timesteps: torch.Tensor) -> torch.Tensor: + def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: """ Adds noise to the original samples according to the noise schedule and the specified timesteps. - This method calculates the noisy samples by combining the original samples with Gaussian noise - scaled according to the time-dependent noise levels dictated by the signal-to-noise ratio. + This method calculates the noisy samples by combining the original samples with Gaussian noise scaled according + to the time-dependent noise levels dictated by the signal-to-noise ratio. Args: original_samples (torch.Tensor): The original samples from the data distribution before noise is added. @@ -309,12 +317,14 @@ def add_noise(self, noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples - def step(self, - model_output: torch.Tensor, - timestep: Union[int, float, torch.Tensor], - sample: torch.Tensor, - generator: Optional[torch.Generator] = None, - return_dict: bool = True) -> Union[VDMSchedulerOutput, Tuple]: + def step( + self, + model_output: torch.Tensor, + timestep: Union[int, float, torch.Tensor], + sample: torch.Tensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[VDMSchedulerOutput, Tuple]: """ Performs a single step of the diffusion process, computing the previous sample and optionally the predicted original sample based on the model output and current timestep. @@ -327,15 +337,15 @@ def step(self, return_dict (bool): If True, returns a `VDMSchedulerOutput` object; otherwise, returns a tuple. Returns: - VDMSchedulerOutput or Tuple: Depending on `return_dict`, returns either a data class containing - the previous sample and predicted original sample, or just the previous sample as a tuple. + VDMSchedulerOutput or Tuple: Depending on `return_dict`, returns either a data class containing the + previous sample and predicted original sample, or just the previous sample as a tuple. """ # Based on https://github.com/addtt/variational-diffusion-models/blob/main/vdm.py#L29 if isinstance(timestep, (int, float)): - timestep = torch.tensor(timestep, - dtype=torch.float32 if isinstance(timestep, float) else torch.int64, - device=sample.device) + timestep = torch.tensor( + timestep, dtype=torch.float32 if isinstance(timestep, float) else torch.int64, device=sample.device + ) if not timestep.is_floating_point(): if not self.config.num_train_timesteps: @@ -367,23 +377,25 @@ def step(self, if self.config.thresholding: pred_original_sample = self._threshold_sample(pred_original_sample) elif self.config.clip_sample: - pred_original_sample = pred_original_sample.clamp(-self.config.clip_sample_range, - self.config.clip_sample_range) + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) # 4. Computed predicted previous sample x_{t-1} c = -torch.expm1(log_snr - prev_log_snr) if self.config.thresholding or self.config.clip_sample or self.config.prediction_type == "sample": - pred_prev_sample = torch.sqrt(prev_alpha) * (sample * (1 - c) / torch.sqrt(alpha) + c * pred_original_sample) + pred_prev_sample = torch.sqrt(prev_alpha) * ( + sample * (1 - c) / torch.sqrt(alpha) + c * pred_original_sample + ) else: pred_prev_sample = torch.sqrt(prev_alpha / alpha) * (sample - c * torch.sqrt(sigma) * model_output) # 5. (Maybe) add noise noise_scale = torch.sqrt(prev_sigma * c) # Becomes 0 for prev_timestep = 0 if torch.any(noise_scale > 0): - noise = randn_tensor(model_output.shape, - generator=generator, - device=model_output.device, - dtype=model_output.dtype) + noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) pred_prev_sample += noise_scale * noise if not return_dict: diff --git a/tests/schedulers/test_schedulers.py b/tests/schedulers/test_schedulers.py index 67621d4214f2..0c0abc8f7205 100755 --- a/tests/schedulers/test_schedulers.py +++ b/tests/schedulers/test_schedulers.py @@ -36,8 +36,8 @@ IPNDMScheduler, LMSDiscreteScheduler, UniPCMultistepScheduler, - VQDiffusionScheduler, VDMScheduler, + VQDiffusionScheduler, ) from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.schedulers.scheduling_utils import SchedulerMixin From 880c7f87ecc805e9c94cb726d45bb0368b6890d7 Mon Sep 17 00:00:00 2001 From: Matthias Humt Date: Mon, 22 Apr 2024 11:47:10 +0200 Subject: [PATCH 13/16] Add VDMScheduler and VDMSchedulerOutput to the API documentation. This commit adds the VDMScheduler and VDMSchedulerOutput classes to the API documentation. The VDMScheduler class is a part of the Variational Diffusion Models (VDM) library, which introduces diffusion-based generative models for image density estimation. The VDMSchedulerOutput class is also included in the documentation. These additions provide users with information on how to use these classes and their functionalities. The commit also includes the necessary copyright and license information. --- docs/source/en/api/schedulers/vdm.md | 41 ++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 docs/source/en/api/schedulers/vdm.md diff --git a/docs/source/en/api/schedulers/vdm.md b/docs/source/en/api/schedulers/vdm.md new file mode 100644 index 000000000000..fa9c8786720f --- /dev/null +++ b/docs/source/en/api/schedulers/vdm.md @@ -0,0 +1,41 @@ + + +# VDMScheduler + +[Variational Diffusion Models](https://arxiv.org/abs/2107.00630) (VDM) by Diederik P. Kingma, Tim Salimans, Ben +Poole and Jonathan Ho introduces a family of diffusion-based generative models that achieve state-of-the-art +log-likelihoods on standard image density estimation benchmarks by formulating diffusion as a continuous-time problem +in terms of the signal-to-noise ratio. + +The abstract from the paper is: + +*Diffusion-based generative models have demonstrated a capacity for perceptually impressive synthesis, but can they +also be great likelihood-based models? We answer this in the affirmative, and introduce a family of diffusion-based +generative models that obtain state-of-the-art likelihoods on standard image density estimation benchmarks. Unlike +other diffusion-based models, our method allows for efficient optimization of the noise schedule jointly with the +rest of the model. We show that the variational lower bound (VLB) simplifies to a remarkably short expression in terms +of the signal-to-noise ratio of the diffused data, thereby improving our theoretical understanding of this model class. +Using this insight, we prove an equivalence between several models proposed in the literature. In addition, we show that +the continuous-time VLB is invariant to the noise schedule, except for the signal-to-noise ratio at its endpoints. This +enables us to learn a noise schedule that minimizes the variance of the resulting VLB estimator, leading to faster +optimization. Combining these advances with architectural improvements, we obtain state-of-the-art likelihoods on image +density estimation benchmarks, outperforming autoregressive models that have dominated these benchmarks for many years, +with often significantly faster optimization. In addition, we show how to use the model as part of a bits-back +compression scheme, and demonstrate lossless compression rates close to the theoretical optimum. Code is available at +[this https URL](https://github.com/google-research/vdm).* + +## VDMScheduler +[[autodoc]] VDMScheduler + +## VDMSchedulerOutput +[[autodoc]] schedulers.scheduling_vdm.VDMSchedulerOutput From a5eedfd2f46c8ff2d142872bf7abbaeea5b96b4e Mon Sep 17 00:00:00 2001 From: Matthias Humt Date: Tue, 11 Jun 2024 11:41:25 +0200 Subject: [PATCH 14/16] Refactor VDMScheduler class in scheduling_vdm.py - Removed the unused `__len__` method. - Removed the `add_noise` method as it is no longer used. - Added a new method `get_velocity` to calculate the velocity based on the noise and sample. - Added support for a new prediction type called `v_prediction`. These changes improve the code structure and add functionality to the VDMScheduler class. --- src/diffusers/schedulers/scheduling_vdm.py | 77 +++++++++++++--------- 1 file changed, 47 insertions(+), 30 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_vdm.py b/src/diffusers/schedulers/scheduling_vdm.py index 52bebadc51f3..61a0f6e82005 100644 --- a/src/diffusers/schedulers/scheduling_vdm.py +++ b/src/diffusers/schedulers/scheduling_vdm.py @@ -146,10 +146,6 @@ def __init__( self.alphas = torch.cat([alphas_cumprod[:1], alphas]) self.betas = 1 - self.alphas - def __len__(self) -> int: - """Returns the number of inference steps or the number of training timesteps or 1000, whichever is set.""" - return self.num_inference_steps or self.config.num_train_timesteps or 1000 - def log_snr(self, timesteps: torch.Tensor) -> torch.Tensor: """ Computes the logarithm of the signal-to-noise ratio for given timesteps using the configured beta schedule. @@ -292,31 +288,6 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None """ return sample - def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: - """ - Adds noise to the original samples according to the noise schedule and the specified timesteps. - - This method calculates the noisy samples by combining the original samples with Gaussian noise scaled according - to the time-dependent noise levels dictated by the signal-to-noise ratio. - - Args: - original_samples (torch.Tensor): The original samples from the data distribution before noise is added. - noise (torch.Tensor): Gaussian noise to be added to the samples. - timesteps (torch.Tensor): Timesteps at which the samples are processed. - - Returns: - torch.Tensor: The noisy samples after adding scaled Gaussian noise according to the SNR. - """ - gamma = self.log_snr(timesteps).to(original_samples.device) - # Reshape from (1,) to (B, ...) where B is the batch size and ... are the spatial dimensions - gamma = gamma.view(timesteps.size(0), *((1,) * (original_samples.ndim - 1))) - - sqrt_alpha_prod = torch.sqrt(torch.sigmoid(gamma)) - sqrt_one_minus_alpha_prod = torch.sqrt(torch.sigmoid(-gamma)) # sqrt(sigma) - - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise - return noisy_samples - def step( self, model_output: torch.Tensor, @@ -370,8 +341,14 @@ def step( pred_original_sample = (sample - torch.sqrt(sigma) * model_output) / torch.sqrt(alpha) # Sec. 3.4, eq. 10 elif self.config.prediction_type == "sample": pred_original_sample = model_output + elif self.config.prediction_type == "v_prediction": + pred_original_sample = torch.sqrt(alpha) * sample + torch.sqrt(sigma) * model_output else: - raise ValueError("`prediction_type` must be either `epsilon` or `sample`.") + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" + f" `v_prediction` for the {self.__class__.__name__}." + ) + # 3. Clip or threshold "predicted x_0" if self.config.thresholding: @@ -402,3 +379,43 @@ def step( return (pred_prev_sample,) return VDMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) + + def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + """ + Adds noise to the original samples according to the noise schedule and the specified timesteps. + + This method calculates the noisy samples by combining the original samples with Gaussian noise scaled according + to the time-dependent noise levels dictated by the signal-to-noise ratio. + + Args: + original_samples (torch.Tensor): The original samples from the data distribution before noise is added. + noise (torch.Tensor): Gaussian noise to be added to the samples. + timesteps (torch.Tensor): Timesteps at which the samples are processed. + + Returns: + torch.Tensor: The noisy samples after adding scaled Gaussian noise according to the SNR. + """ + gamma = self.log_snr(timesteps).to(original_samples.device) + # Reshape from (1,) to (B, ...) where B is the batch size and ... are the spatial dimensions + gamma = gamma.view(timesteps.size(0), *((1,) * (original_samples.ndim - 1))) + + sqrt_alpha_prod = torch.sqrt(torch.sigmoid(gamma)) + sqrt_one_minus_alpha_prod = torch.sqrt(torch.sigmoid(-gamma)) # sqrt(sigma) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: + gamma = self.log_snr(timesteps).to(original_samples.device) + # Reshape from (1,) to (B, ...) where B is the batch size and ... are the spatial dimensions + gamma = gamma.view(timesteps.size(0), *((1,) * (original_samples.ndim - 1))) + + sqrt_alpha_prod = torch.sqrt(torch.sigmoid(gamma)) + sqrt_one_minus_alpha_prod = torch.sqrt(torch.sigmoid(-gamma)) # sqrt(sigma) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self) -> int: + """Returns the number of inference steps or the number of training timesteps or 1000, whichever is set.""" + return self.num_inference_steps or self.config.num_train_timesteps or 1000 From 9fc5b585017750a7db347944a7d7f86b478e6e68 Mon Sep 17 00:00:00 2001 From: Matthias Humt Date: Tue, 11 Jun 2024 17:38:58 +0200 Subject: [PATCH 15/16] Refactor VDMScheduler in scheduling_vdm.py - Refactored the `get_velocity` method to correctly handle the device of the input tensors. - Removed unnecessary comment and reshaping code in `get_velocity` method. - Modified the condition in the `if` statement in the `__init__` method to use the `!=` operator instead of `==` for `prediction_type`. These changes improve the code readability and ensure correct device handling for tensors in the VDMScheduler class. --- src/diffusers/schedulers/scheduling_vdm.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_vdm.py b/src/diffusers/schedulers/scheduling_vdm.py index 61a0f6e82005..137995ce04bf 100644 --- a/src/diffusers/schedulers/scheduling_vdm.py +++ b/src/diffusers/schedulers/scheduling_vdm.py @@ -360,7 +360,7 @@ def step( # 4. Computed predicted previous sample x_{t-1} c = -torch.expm1(log_snr - prev_log_snr) - if self.config.thresholding or self.config.clip_sample or self.config.prediction_type == "sample": + if self.config.thresholding or self.config.clip_sample or self.config.prediction_type != "epsilon": pred_prev_sample = torch.sqrt(prev_alpha) * ( sample * (1 - c) / torch.sqrt(alpha) + c * pred_original_sample ) @@ -396,7 +396,6 @@ def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timeste torch.Tensor: The noisy samples after adding scaled Gaussian noise according to the SNR. """ gamma = self.log_snr(timesteps).to(original_samples.device) - # Reshape from (1,) to (B, ...) where B is the batch size and ... are the spatial dimensions gamma = gamma.view(timesteps.size(0), *((1,) * (original_samples.ndim - 1))) sqrt_alpha_prod = torch.sqrt(torch.sigmoid(gamma)) @@ -405,10 +404,9 @@ def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timeste noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples - def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: - gamma = self.log_snr(timesteps).to(original_samples.device) - # Reshape from (1,) to (B, ...) where B is the batch size and ... are the spatial dimensions - gamma = gamma.view(timesteps.size(0), *((1,) * (original_samples.ndim - 1))) + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + gamma = self.log_snr(timesteps).to(sample.device) + gamma = gamma.view(timesteps.size(0), *((1,) * (sample.ndim - 1))) sqrt_alpha_prod = torch.sqrt(torch.sigmoid(gamma)) sqrt_one_minus_alpha_prod = torch.sqrt(torch.sigmoid(-gamma)) # sqrt(sigma) From 766f142411e37cf6c92d1f7753ba91a5d615450e Mon Sep 17 00:00:00 2001 From: Matthias Humt Date: Wed, 12 Jun 2024 10:20:37 +0200 Subject: [PATCH 16/16] Fix prediction calculation in VDMScheduler The commit fixes a bug in the calculation of `pred_original_sample` in the `VDMScheduler` class. Previously, the calculation used addition, but it should use subtraction. This bug affected the `v_prediction` prediction type. The fix ensures that the calculation is correct and consistent with the intended behavior of the `v_prediction` type. --- src/diffusers/schedulers/scheduling_vdm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_vdm.py b/src/diffusers/schedulers/scheduling_vdm.py index 137995ce04bf..907f5a037ff9 100644 --- a/src/diffusers/schedulers/scheduling_vdm.py +++ b/src/diffusers/schedulers/scheduling_vdm.py @@ -342,7 +342,7 @@ def step( elif self.config.prediction_type == "sample": pred_original_sample = model_output elif self.config.prediction_type == "v_prediction": - pred_original_sample = torch.sqrt(alpha) * sample + torch.sqrt(sigma) * model_output + pred_original_sample = torch.sqrt(alpha) * sample - torch.sqrt(sigma) * model_output else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"