@@ -162,6 +162,7 @@ def __init__(
162
162
self .init_noise_sigma = 1.0
163
163
164
164
# setable values
165
+ self .custom_timesteps = False
165
166
self .num_inference_steps = None
166
167
self .timesteps = torch .from_numpy (np .arange (0 , num_train_timesteps )[::- 1 ].copy ())
167
168
@@ -191,14 +192,31 @@ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] =
191
192
"""
192
193
return sample
193
194
194
- def set_timesteps (self , num_inference_steps : int , device : Union [str , torch .device ] = None ):
195
+ def set_timesteps (
196
+ self ,
197
+ num_inference_steps : int ,
198
+ device : Union [str , torch .device ] = None ,
199
+ custom_timesteps : Optional [List [int ]] = None ,
200
+ ):
195
201
"""
196
202
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
197
203
198
204
Args:
199
205
num_inference_steps (`int`):
200
206
the number of diffusion steps used when generating samples with a pre-trained model.
207
+ device (`str` or `torch.device`, optional):
208
+ the device to which the timesteps are moved to.
209
+ custom_timesteps (`List[int]`, optional):
210
+ custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
211
+ timestep spacing strategy of equal spacing between timesteps is used.
212
+
201
213
"""
214
+ if custom_timesteps is not None :
215
+ num_inference_steps = len (custom_timesteps )
216
+
217
+ for i in range (1 , len (custom_timesteps )):
218
+ if custom_timesteps [i ] >= custom_timesteps [i - 1 ]:
219
+ raise ValueError ("`custom_timesteps` must be in descending order." )
202
220
203
221
if num_inference_steps > self .config .num_train_timesteps :
204
222
raise ValueError (
@@ -209,13 +227,19 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
209
227
210
228
self .num_inference_steps = num_inference_steps
211
229
212
- step_ratio = self .config .num_train_timesteps // self .num_inference_steps
213
- timesteps = (np .arange (0 , num_inference_steps ) * step_ratio ).round ()[::- 1 ].copy ().astype (np .int64 )
230
+ if custom_timesteps is None :
231
+ step_ratio = self .config .num_train_timesteps // self .num_inference_steps
232
+ timesteps = (np .arange (0 , num_inference_steps ) * step_ratio ).round ()[::- 1 ].copy ().astype (np .int64 )
233
+ self .custom_timesteps = False
234
+ else :
235
+ timesteps = np .array (custom_timesteps , dtype = np .int64 )
236
+ self .custom_timesteps = True
237
+
214
238
self .timesteps = torch .from_numpy (timesteps ).to (device )
215
239
216
240
def _get_variance (self , t , predicted_variance = None , variance_type = None ):
217
- num_inference_steps = self .num_inference_steps if self . num_inference_steps else self . config . num_train_timesteps
218
- prev_t = t - self . config . num_train_timesteps // num_inference_steps
241
+ prev_t = self .previous_timestep ( t )
242
+
219
243
alpha_prod_t = self .alphas_cumprod [t ]
220
244
alpha_prod_t_prev = self .alphas_cumprod [prev_t ] if prev_t >= 0 else self .one
221
245
current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
@@ -314,8 +338,8 @@ def step(
314
338
315
339
"""
316
340
t = timestep
317
- num_inference_steps = self . num_inference_steps if self . num_inference_steps else self . config . num_train_timesteps
318
- prev_t = timestep - self .config . num_train_timesteps // num_inference_steps
341
+
342
+ prev_t = self .previous_timestep ( t )
319
343
320
344
if model_output .shape [1 ] == sample .shape [1 ] * 2 and self .variance_type in ["learned" , "learned_range" ]:
321
345
model_output , predicted_variance = torch .split (model_output , sample .shape [1 ], dim = 1 )
@@ -428,3 +452,18 @@ def get_velocity(
428
452
429
453
def __len__(self):
    """Length of the scheduler: the total number of training timesteps."""
    return self.config.num_train_timesteps
455
+
456
def previous_timestep(self, timestep):
    """
    Return the timestep that precedes `timestep` in the current schedule.

    With a custom schedule (`self.custom_timesteps` is True), the predecessor is
    the next entry of `self.timesteps` after `timestep`, or a tensor holding -1
    when `timestep` is the schedule's final entry. Otherwise the predecessor is
    `timestep` minus the uniform stride `num_train_timesteps // num_inference_steps`.
    """
    if not self.custom_timesteps:
        # Uniform spacing. Fall back to the full training schedule (stride 1)
        # when `set_timesteps` has not been called yet.
        steps = self.num_inference_steps or self.config.num_train_timesteps
        return timestep - self.config.num_train_timesteps // steps

    # Locate `timestep` inside the (descending) custom schedule.
    position = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
    if position == self.timesteps.shape[0] - 1:
        # Last entry: nothing precedes it, signal with -1.
        return torch.tensor(-1)
    return self.timesteps[position + 1]
0 commit comments