
Commit e4546fd

a-r-r-o-w and sayakpaul authored
[docs] Add missing copied from statements in TCD Scheduler (#7360)
* add missing copied from statements in tcd scheduler
* update docstring

---------

Co-authored-by: Sayak Paul <[email protected]>
1 parent d44e31a commit e4546fd

File tree

1 file changed: +13 -3 lines changed


src/diffusers/schedulers/scheduling_tcd.py

Lines changed: 13 additions & 3 deletions
@@ -307,6 +307,7 @@ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] =
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.
+
         Returns:
             `torch.FloatTensor`:
                 A scaled input sample.
@@ -364,7 +365,7 @@ def set_timesteps(
         device: Union[str, torch.device] = None,
         original_inference_steps: Optional[int] = None,
         timesteps: Optional[List[int]] = None,
-        strength: int = 1.0,
+        strength: float = 1.0,
     ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -384,6 +385,8 @@ def set_timesteps(
                 Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
                 timestep spacing strategy of equal spacing between timesteps on the training/distillation timestep
                 schedule is used. If `timesteps` is passed, `num_inference_steps` must be `None`.
+            strength (`float`, *optional*, defaults to 1.0):
+                Used to determine the number of timesteps used for inference when using img2img, inpaint, etc.
         """
         # 0. Check inputs
         if num_inference_steps is None and timesteps is None:
@@ -624,14 +627,18 @@ def step(
 
         return TCDSchedulerOutput(prev_sample=prev_sample, pred_noised_sample=pred_noised_sample)
 
+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
     def add_noise(
         self,
         original_samples: torch.FloatTensor,
         noise: torch.FloatTensor,
         timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
-        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
+        # for the subsequent add_noise calls
+        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
+        alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
         timesteps = timesteps.to(original_samples.device)
 
         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
@@ -647,11 +654,13 @@ def add_noise(
         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
 
+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
     def get_velocity(
         self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as sample
-        alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
+        alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
         timesteps = timesteps.to(sample.device)
 
         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
@@ -670,6 +679,7 @@ def get_velocity(
     def __len__(self):
         return self.config.num_train_timesteps
 
+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
     def previous_timestep(self, timestep):
         if self.custom_timesteps:
             index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
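
The `alphas_cumprod` change above is a small caching optimization: on the first `add_noise` (or `get_velocity`) call, the tensor is moved to the sample's device and kept there on the scheduler, so later calls skip the repeated CPU-to-GPU copy. A minimal sketch of the call pattern that benefits, assuming a CUDA device is available; the tensor shapes are illustrative only:

```python
import torch
from diffusers import TCDScheduler

scheduler = TCDScheduler()
device = "cuda"  # assumption: a CUDA device is available; on CPU the move is a no-op

# Illustrative latent shapes; any (batch, channels, height, width) tensor works the same way.
sample = torch.randn(1, 4, 64, 64, device=device)
noise = torch.randn_like(sample)
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (1,), device=device)

# First call: self.alphas_cumprod is moved to `device` once and cached on the scheduler.
noisy = scheduler.add_noise(sample, noise, timesteps)

# Subsequent calls (e.g. every img2img/inpaint preparation step) reuse the on-device
# tensor instead of transferring it from CPU again.
noisy = scheduler.add_noise(sample, noise, timesteps)
```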

0 commit comments
