26 changes: 6 additions & 20 deletions src/diffusers/schedulers/scheduling_sasolver.py
@@ -328,9 +328,7 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
         batch_size, channels, *remaining_dims = sample.shape
 
         if dtype not in (torch.float32, torch.float64):
-            sample = (
-                sample.float()
-            )  # upcast for quantile calculation, and clamp not implemented for cpu half
+            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half
 
         # Flatten sample for doing quantile calculation along each image
         sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
@@ -342,9 +340,7 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
             s, min=1, max=self.config.sample_max_value
         )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
         s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
-        sample = (
-            torch.clamp(sample, -s, s) / s
-        )  # "we threshold xt0 to the range [-s, s] and then divide by s"
+        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"
 
         sample = sample.reshape(batch_size, channels, *remaining_dims)
         sample = sample.to(dtype)
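For context, the two hunks above reformat the dynamic thresholding step (Imagen-style) that the inline comments quote: the predicted x0 is clamped to a per-sample quantile s and rescaled. Below is a minimal standalone sketch of that step, assuming a float32 sample of shape (batch, channels, height, width); the function name and the ratio/max_value defaults are illustrative stand-ins for the scheduler's dynamic_thresholding_ratio and sample_max_value config options, not part of the diff.

import torch

def dynamic_threshold(sample: torch.Tensor, ratio: float = 0.995, max_value: float = 1.0) -> torch.Tensor:
    # Work per sample: flatten everything but the batch dim before taking the quantile
    batch_size = sample.shape[0]
    flat = sample.reshape(batch_size, -1)
    s = torch.quantile(flat.abs(), ratio, dim=1)
    s = torch.clamp(s, min=1, max=max_value)  # min=1 degrades gracefully to plain [-1, 1] clipping
    s = s.unsqueeze(1)                        # (batch_size, 1) so clamp broadcasts over the flattened dim
    flat = torch.clamp(flat, -s, s) / s       # threshold to [-s, s], then divide by s
    return flat.reshape(sample.shape)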
@@ -360,11 +356,7 @@ def _sigma_to_t(self, sigma, log_sigmas):
         dists = log_sigma - log_sigmas[:, np.newaxis]
 
         # get sigmas range
-        low_idx = (
-            np.cumsum((dists >= 0), axis=0)
-            .argmax(axis=0)
-            .clip(max=log_sigmas.shape[0] - 2)
-        )
+        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
         high_idx = low_idx + 1
 
         low = log_sigmas[low_idx]
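For context, the reformatted low_idx line finds the bracketing log-sigma indices for a query sigma; the scheduler then interpolates linearly between them to recover a fractional timestep. A rough standalone sketch of the same idea, using np.searchsorted instead of the cumsum/argmax trick and assuming log_sigmas is sorted in increasing order (this is a sketch, not the scheduler's exact code):

import numpy as np

def sigma_to_t(sigma: float, log_sigmas: np.ndarray) -> float:
    log_sigma = np.log(sigma)
    # index of the largest log-sigma still <= log_sigma, capped so high_idx stays in range
    low_idx = int(np.clip(np.searchsorted(log_sigmas, log_sigma) - 1, 0, log_sigmas.shape[0] - 2))
    high_idx = low_idx + 1
    low, high = log_sigmas[low_idx], log_sigmas[high_idx]
    w = np.clip((low - log_sigma) / (low - high), 0, 1)  # interpolation weight between the two indices
    return float((1 - w) * low_idx + w * high_idx)       # fractional timestep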
@@ -387,9 +379,7 @@ def _sigma_to_alpha_sigma_t(self, sigma):
         return alpha_t, sigma_t
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
-    def _convert_to_karras(
-        self, in_sigmas: torch.FloatTensor, num_inference_steps
-    ) -> torch.FloatTensor:
+    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
         """Constructs the noise schedule of Karras et al. (2022)."""
 
         # Hack to make sure that other schedulers which copy this function don't break
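For context, _convert_to_karras rebuilds the sigma schedule following Karras et al. (2022), which spaces noise levels between sigma_max and sigma_min in rho-warped space (rho = 7.0 in the paper). A minimal sketch of that formula, with illustrative argument names:

import numpy as np

def karras_sigmas(sigma_min: float, sigma_max: float, num_steps: int, rho: float = 7.0) -> np.ndarray:
    # Evenly spaced ramp in [0, 1], warped so that steps cluster near sigma_min
    ramp = np.linspace(0, 1, num_steps)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho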
@@ -1200,9 +1190,7 @@ def add_noise(
         timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
-        alphas_cumprod = self.alphas_cumprod.to(
-            device=original_samples.device, dtype=original_samples.dtype
-        )
+        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
         timesteps = timesteps.to(original_samples.device)
 
         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
@@ -1215,9 +1203,7 @@ def add_noise(
         while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
             sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
 
-        noisy_samples = (
-            sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
-        )
+        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
 
     def __len__(self):
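For context, add_noise implements the standard DDPM forward process, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, with the two per-sample factors broadcast over the remaining sample dimensions. A condensed, self-contained sketch of the same computation, written as a free function over an externally supplied alphas_cumprod:

import torch

def add_noise(original_samples: torch.Tensor, noise: torch.Tensor,
              timesteps: torch.Tensor, alphas_cumprod: torch.Tensor) -> torch.Tensor:
    alphas_cumprod = alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
    a_bar = alphas_cumprod[timesteps.to(original_samples.device)]
    sqrt_alpha_prod = a_bar ** 0.5
    sqrt_one_minus_alpha_prod = (1 - a_bar) ** 0.5
    # Reshape the (batch,) factors to (batch, 1, 1, ...) so they broadcast over the sample dims
    while sqrt_alpha_prod.ndim < original_samples.ndim:
        sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
    return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise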