
Commit 11f527a

Add Karras sigmas to HeunDiscreteScheduler (#3160)
* Add Karras pattern to the discrete Heun scheduler
* Add integration test
* Fix failing CI for the PyTorch test on M1 (mps)

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 2c04e58 commit 11f527a

File tree: 2 files changed (+76, −1 lines)

src/diffusers/schedulers/scheduling_heun_discrete.py

Lines changed: 51 additions & 1 deletion
@@ -75,7 +75,11 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         prediction_type (`str`, default `epsilon`, optional):
             prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
             process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4 of
-            https://imagen.research.google/video/paper.pdf)
+            https://imagen.research.google/video/paper.pdf).
+        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use Karras sigmas (the Karras et al. (2022) scheme) for the step sizes in the noise schedule
+            during sampling. If `True`, the sigmas are determined according to a sequence of noise levels {σi} as
+            defined in Equation (5) of https://arxiv.org/pdf/2206.00364.pdf.
     """

     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -90,6 +94,7 @@ def __init__(
         beta_schedule: str = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         prediction_type: str = "epsilon",
+        use_karras_sigmas: Optional[bool] = False,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -111,6 +116,7 @@ def __init__(

         # set all values
         self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
+        self.use_karras_sigmas = use_karras_sigmas

     def index_for_timestep(self, timestep, schedule_timesteps=None):
         if schedule_timesteps is None:
@@ -165,7 +171,13 @@ def set_timesteps(
         timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()

         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        log_sigmas = np.log(sigmas)
         sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+
+        if self.use_karras_sigmas:
+            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
         sigmas = torch.from_numpy(sigmas).to(device=device)
         self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])
@@ -186,6 +198,44 @@ def set_timesteps(
         self.prev_derivative = None
         self.dt = None

+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
+    def _sigma_to_t(self, sigma, log_sigmas):
+        # get log sigma
+        log_sigma = np.log(sigma)
+
+        # get distribution
+        dists = log_sigma - log_sigmas[:, np.newaxis]
+
+        # get sigmas range
+        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
+        high_idx = low_idx + 1
+
+        low = log_sigmas[low_idx]
+        high = log_sigmas[high_idx]
+
+        # interpolate sigmas
+        w = (low - log_sigma) / (low - high)
+        w = np.clip(w, 0, 1)
+
+        # transform interpolation to time range
+        t = (1 - w) * low_idx + w * high_idx
+        t = t.reshape(sigma.shape)
+        return t
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
+    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+        """Constructs the noise schedule of Karras et al. (2022)."""
+
+        sigma_min: float = in_sigmas[-1].item()
+        sigma_max: float = in_sigmas[0].item()
+
+        rho = 7.0  # 7.0 is the value used in the paper
+        ramp = np.linspace(0, 1, num_inference_steps)
+        min_inv_rho = sigma_min ** (1 / rho)
+        max_inv_rho = sigma_max ** (1 / rho)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+        return sigmas
+
     @property
     def state_in_first_order(self):
         return self.dt is None
tests/schedulers/test_scheduler_heun.py

Lines changed: 25 additions & 0 deletions
@@ -129,3 +129,28 @@ def test_full_loop_device(self):
         # CUDA
         assert abs(result_sum.item() - 0.1233) < 1e-2
         assert abs(result_mean.item() - 0.0002) < 1e-3
+
+    def test_full_loop_device_karras_sigmas(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config, use_karras_sigmas=True)
+
+        scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter.to(torch_device) * scheduler.init_noise_sigma
+        sample = sample.to(torch_device)
+
+        for t in scheduler.timesteps:
+            sample = scheduler.scale_model_input(sample, t)
+
+            model_output = model(sample, t)
+
+            output = scheduler.step(model_output, t, sample)
+            sample = output.prev_sample
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_sum.item() - 0.00015) < 1e-2
+        assert abs(result_mean.item() - 1.9869554535034695e-07) < 1e-2
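
A minimal usage sketch of the new flag, following the standard diffusers pattern of rebuilding a pipeline's scheduler via from_config; the model id below is only an example checkpoint, not one named in this commit.

# Usage sketch (assumption: any Stable Diffusion checkpoint works here;
# "runwayml/stable-diffusion-v1-5" is only an example id).
import torch
from diffusers import StableDiffusionPipeline, HeunDiscreteScheduler

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# Rebuild the scheduler from the pipeline's config with Karras sigmas enabled.
pipe.scheduler = HeunDiscreteScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]

The new integration test can then be run in isolation with, for example, pytest tests/schedulers/test_scheduler_heun.py -k karras.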
