@@ -162,6 +162,7 @@ def __init__(
         self.init_noise_sigma = 1.0
 
         # setable values
+        self.custom_timesteps = False
         self.num_inference_steps = None
         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
 
@@ -191,31 +192,62 @@ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] =
         """
         return sample
 
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(
+        self,
+        num_inference_steps: Optional[int] = None,
+        device: Union[str, torch.device] = None,
+        timesteps: Optional[List[int]] = None,
+    ):
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
 
         Args:
-            num_inference_steps (`int`):
-                the number of diffusion steps used when generating samples with a pre-trained model.
+            num_inference_steps (`Optional[int]`):
+                the number of diffusion steps used when generating samples with a pre-trained model. If passed, then
+                `timesteps` must be `None`.
+            device (`str` or `torch.device`, optional):
+                the device to which the timesteps are moved.
+            timesteps (`List[int]`, optional):
+                custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+                timestep spacing strategy of equal spacing between timesteps is used. If passed, `num_inference_steps`
+                must be `None`.
+
         """
+        if num_inference_steps is not None and timesteps is not None:
+            raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.")
+
+        if timesteps is not None:
+            for i in range(1, len(timesteps)):
+                if timesteps[i] >= timesteps[i - 1]:
+                    raise ValueError("`timesteps` must be in descending order.")
+
+            if timesteps[0] >= self.config.num_train_timesteps:
+                raise ValueError(
+                    f"`timesteps` must start before `self.config.num_train_timesteps`:"
+                    f" {self.config.num_train_timesteps}."
+                )
+
+            timesteps = np.array(timesteps, dtype=np.int64)
+            self.custom_timesteps = True
+        else:
+            if num_inference_steps > self.config.num_train_timesteps:
+                raise ValueError(
+                    f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.num_train_timesteps`:"
+                    f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+                    f" maximal {self.config.num_train_timesteps} timesteps."
+                )
 
-        if num_inference_steps > self.config.num_train_timesteps:
-            raise ValueError(
-                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
-                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
-                f" maximal {self.config.num_train_timesteps} timesteps."
-            )
+            self.num_inference_steps = num_inference_steps
 
-        self.num_inference_steps = num_inference_steps
+            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+            self.custom_timesteps = False
 
-        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
-        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
         self.timesteps = torch.from_numpy(timesteps).to(device)
 
     def _get_variance(self, t, predicted_variance=None, variance_type=None):
-        num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-        prev_t = t - self.config.num_train_timesteps // num_inference_steps
+        prev_t = self.previous_timestep(t)
+
         alpha_prod_t = self.alphas_cumprod[t]
         alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
         current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
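As a quick sanity check on the two calling modes of `set_timesteps` (a minimal sketch, assuming this patch lands in diffusers' `DDPMScheduler`; the import path is the only assumption beyond the diff itself):

```python
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)

# Default mode: equally spaced timesteps, unchanged behavior.
scheduler.set_timesteps(num_inference_steps=10)
print(scheduler.timesteps)         # tensor([900, 800, ..., 100, 0])
print(scheduler.custom_timesteps)  # False

# New mode: arbitrary spacing, passed as a strictly descending list.
scheduler.set_timesteps(timesteps=[999, 750, 500, 250, 100, 25, 0])
print(scheduler.timesteps)         # tensor([999, 750, 500, 250, 100, 25, 0])
print(scheduler.custom_timesteps)  # True

# Both of these raise ValueError under the checks above:
# scheduler.set_timesteps(num_inference_steps=10, timesteps=[999, 0])
# scheduler.set_timesteps(timesteps=[0, 500, 999])  # not descending
```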
@@ -314,8 +346,8 @@ def step(
 
         """
         t = timestep
-        num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-        prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
+
+        prev_t = self.previous_timestep(t)
 
         if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
             model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
@@ -428,3 +460,18 @@ def get_velocity(
 
     def __len__(self):
         return self.config.num_train_timesteps
+
+    def previous_timestep(self, timestep):
+        if self.custom_timesteps:
+            index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
+            if index == self.timesteps.shape[0] - 1:
+                prev_t = torch.tensor(-1)
+            else:
+                prev_t = self.timesteps[index + 1]
+        else:
+            num_inference_steps = (
+                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
+            )
+            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
+
+        return prev_t
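To make the two branches of the new helper concrete (a sketch continuing the `scheduler` example above; the printed values follow from the arithmetic in the diff, nothing else is assumed):

```python
# Custom spacing: the previous timestep is just the next entry in the list;
# the last entry steps to -1, which _get_variance maps to alpha_prod_t_prev = self.one.
scheduler.set_timesteps(timesteps=[999, 500, 0])
print(scheduler.previous_timestep(999))  # tensor(500)
print(scheduler.previous_timestep(0))    # tensor(-1)

# Equal spacing: falls back to the old fixed-stride arithmetic.
scheduler.set_timesteps(num_inference_steps=10)
print(scheduler.previous_timestep(900))  # 800, i.e. 900 - 1000 // 10
```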