@@ -75,7 +75,11 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         prediction_type (`str`, default `epsilon`, optional):
             prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
             process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
-            https://imagen.research.google/video/paper.pdf)
+            https://imagen.research.google/video/paper.pdf).
+        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use Karras sigmas (the Karras et al. (2022) scheme) for the step sizes in the noise
+            schedule during sampling. If `True`, the sigmas are determined according to the sequence of noise
+            levels {σi} defined in Equation (5) of https://arxiv.org/pdf/2206.00364.pdf.
     """

     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
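For reference, the `use_karras_sigmas` docstring above points to Equation (5) of Karras et al. (2022). A minimal standalone sketch of that schedule, using made-up `sigma_min`/`sigma_max` values rather than anything computed by this scheduler:

```python
import numpy as np

# Illustrative sketch of the Karras et al. (2022) sigma schedule (Eq. 5).
# sigma_min, sigma_max, and the step count are example values, not taken
# from this scheduler's alphas_cumprod.
sigma_min, sigma_max, rho = 0.1, 10.0, 7.0
num_inference_steps = 5

ramp = np.linspace(0, 1, num_inference_steps)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
print(sigmas)  # ~[10.0, 4.06, 1.45, 0.43, 0.1]: steps cluster near sigma_min
```

Larger `rho` biases the schedule toward small sigmas, which is where sampling benefits most from fine steps.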
@@ -90,6 +94,7 @@ def __init__(
         beta_schedule: str = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         prediction_type: str = "epsilon",
+        use_karras_sigmas: Optional[bool] = False,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -111,6 +116,7 @@ def __init__(
         # set all values
         self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
+        self.use_karras_sigmas = use_karras_sigmas

     def index_for_timestep(self, timestep, schedule_timesteps=None):
         if schedule_timesteps is None:
@@ -165,7 +171,13 @@ def set_timesteps(
         timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()

         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        log_sigmas = np.log(sigmas)
         sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+
+        if self.use_karras_sigmas:
+            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
         sigmas = torch.from_numpy(sigmas).to(device=device)
         self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])
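Once the Karras sigmas replace the interpolated ones, they no longer correspond to integer timesteps, which is why each sigma is mapped back to a fractional timestep via `_sigma_to_t` (added in the next hunk). A toy check of that log-space interpolation, on a hypothetical four-entry schedule:

```python
import numpy as np

# Toy check of the _sigma_to_t log-space interpolation: sigma = 1.0 is the
# geometric mean of 0.5 and 2.0, so its fractional timestep should land
# exactly halfway between indices 1 and 2. The schedule values are made up.
log_sigmas = np.log(np.array([0.1, 0.5, 2.0, 10.0]))
log_sigma = np.log(1.0)

low_idx = np.cumsum(log_sigma >= log_sigmas).argmax().clip(max=len(log_sigmas) - 2)
high_idx = low_idx + 1
w = np.clip((log_sigmas[low_idx] - log_sigma) / (log_sigmas[low_idx] - log_sigmas[high_idx]), 0, 1)
print((1 - w) * low_idx + w * high_idx)  # 1.5
```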
@@ -186,6 +198,44 @@ def set_timesteps(
         self.prev_derivative = None
         self.dt = None

+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
+    def _sigma_to_t(self, sigma, log_sigmas):
+        # get log sigma
+        log_sigma = np.log(sigma)
+
+        # get distribution
+        dists = log_sigma - log_sigmas[:, np.newaxis]
+
+        # get sigmas range
+        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
+        high_idx = low_idx + 1
+
+        low = log_sigmas[low_idx]
+        high = log_sigmas[high_idx]
+
+        # interpolate sigmas
+        w = (low - log_sigma) / (low - high)
+        w = np.clip(w, 0, 1)
+
+        # transform interpolation to time range
+        t = (1 - w) * low_idx + w * high_idx
+        t = t.reshape(sigma.shape)
+        return t
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
+    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+        """Constructs the noise schedule of Karras et al. (2022)."""
+
+        sigma_min: float = in_sigmas[-1].item()
+        sigma_max: float = in_sigmas[0].item()
+
+        rho = 7.0  # 7.0 is the value used in the paper
+        ramp = np.linspace(0, 1, num_inference_steps)
+        min_inv_rho = sigma_min ** (1 / rho)
+        max_inv_rho = sigma_max ** (1 / rho)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+        return sigmas
+
     @property
     def state_in_first_order(self):
         return self.dt is None
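Putting it together, a hedged usage sketch of the new flag (this assumes a diffusers build that includes this patch; `HeunDiscreteScheduler` and `set_timesteps` are the existing public API):

```python
from diffusers import HeunDiscreteScheduler

# Sketch of enabling the new flag (assumes a diffusers version that
# includes this patch). All other constructor arguments keep their defaults.
scheduler = HeunDiscreteScheduler(use_karras_sigmas=True)
scheduler.set_timesteps(num_inference_steps=25)

# With use_karras_sigmas=True, self.sigmas follows the Karras spacing and
# self.timesteps holds the fractional timesteps recovered by _sigma_to_t.
print(scheduler.sigmas)
print(scheduler.timesteps)
```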