@@ -307,6 +307,7 @@ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] =
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.
+
         Returns:
             `torch.FloatTensor`:
                 A scaled input sample.
@@ -364,7 +365,7 @@ def set_timesteps(
         device: Union[str, torch.device] = None,
         original_inference_steps: Optional[int] = None,
         timesteps: Optional[List[int]] = None,
-        strength: int = 1.0,
+        strength: float = 1.0,
     ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -384,6 +385,8 @@ def set_timesteps(
                 Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
                 timestep spacing strategy of equal spacing between timesteps on the training/distillation timestep
                 schedule is used. If `timesteps` is passed, `num_inference_steps` must be `None`.
+            strength (`float`, *optional*, defaults to 1.0):
+                Used to determine the number of timesteps used for inference when using img2img, inpaint, etc.
         """
         # 0. Check inputs
         if num_inference_steps is None and timesteps is None:
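# Illustrative usage sketch, not part of the diff above: how `strength` might be passed
# to `set_timesteps` by an img2img / inpaint pipeline. This assumes the diffusers
# TCDScheduler with its default config; the reading of `strength` as restricting the
# schedule to roughly the first `strength` fraction of the training timestep range
# follows how LCM-style schedulers treat it and is not spelled out in this diff.
from diffusers import TCDScheduler

scheduler = TCDScheduler()

# Text-to-image: use the full schedule.
scheduler.set_timesteps(num_inference_steps=8, device="cpu", strength=1.0)
print(scheduler.timesteps)

# Img2img-style: a smaller strength keeps the sampled timesteps in the less-noisy part
# of the training range, so denoising starts closer to the init image.
scheduler.set_timesteps(num_inference_steps=8, device="cpu", strength=0.5)
print(scheduler.timesteps)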
@@ -624,14 +627,18 @@ def step(
 
         return TCDSchedulerOutput(prev_sample=prev_sample, pred_noised_sample=pred_noised_sample)
 
+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
     def add_noise(
         self,
         original_samples: torch.FloatTensor,
         noise: torch.FloatTensor,
         timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
-        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
+        # for the subsequent add_noise calls
+        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
+        alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
         timesteps = timesteps.to(original_samples.device)
 
         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
@@ -647,11 +654,13 @@ def add_noise(
         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
 
+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
     def get_velocity(
         self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as sample
-        alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
+        alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
         timesteps = timesteps.to(sample.device)
 
         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
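# Minimal standalone sketch of the device-caching pattern in the hunks above
# (illustrative only: `AlphaCumprodCache`, `gather`, and the beta schedule are made up
# for this example). The first call pays the CPU -> GPU copy once by persisting the
# moved tensor on `self`; later `add_noise` / `get_velocity`-style calls then only cast
# the dtype instead of re-transferring `alphas_cumprod` every time.
import torch


class AlphaCumprodCache:
    def __init__(self, num_train_timesteps: int = 1000):
        # Built on the CPU at construction time, like the scheduler's `self.alphas_cumprod`.
        betas = torch.linspace(1e-4, 2e-2, num_train_timesteps)
        self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    def gather(self, timesteps: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
        # Persist the device move so subsequent calls skip the host-to-device transfer.
        self.alphas_cumprod = self.alphas_cumprod.to(device=like.device)
        # Mirror the diff: the dtype cast is kept local rather than persisted, likely so
        # mixed-precision callers do not permanently downcast the stored values.
        alphas_cumprod = self.alphas_cumprod.to(dtype=like.dtype)
        return alphas_cumprod[timesteps.to(like.device)]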
@@ -670,6 +679,7 @@ def get_velocity(
     def __len__(self):
         return self.config.num_train_timesteps
 
+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
     def previous_timestep(self, timestep):
         if self.custom_timesteps:
             index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]