diff --git a/src/diffusers/schedulers/scheduling_sasolver.py b/src/diffusers/schedulers/scheduling_sasolver.py index e09db049ce99..25f559fcd299 100644 --- a/src/diffusers/schedulers/scheduling_sasolver.py +++ b/src/diffusers/schedulers/scheduling_sasolver.py @@ -328,9 +328,7 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: batch_size, channels, *remaining_dims = sample.shape if dtype not in (torch.float32, torch.float64): - sample = ( - sample.float() - ) # upcast for quantile calculation, and clamp not implemented for cpu half + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half # Flatten sample for doing quantile calculation along each image sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) @@ -342,9 +340,7 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: s, min=1, max=self.config.sample_max_value ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 - sample = ( - torch.clamp(sample, -s, s) / s - ) # "we threshold xt0 to the range [-s, s] and then divide by s" + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" sample = sample.reshape(batch_size, channels, *remaining_dims) sample = sample.to(dtype) @@ -360,11 +356,7 @@ def _sigma_to_t(self, sigma, log_sigmas): dists = log_sigma - log_sigmas[:, np.newaxis] # get sigmas range - low_idx = ( - np.cumsum((dists >= 0), axis=0) - .argmax(axis=0) - .clip(max=log_sigmas.shape[0] - 2) - ) + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) high_idx = low_idx + 1 low = log_sigmas[low_idx] @@ -387,9 +379,7 @@ def _sigma_to_alpha_sigma_t(self, sigma): return alpha_t, sigma_t # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras - def _convert_to_karras( - self, in_sigmas: torch.FloatTensor, num_inference_steps - ) -> torch.FloatTensor: + def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: """Constructs the noise schedule of Karras et al. (2022).""" # Hack to make sure that other schedulers which copy this function don't break @@ -1200,9 +1190,7 @@ def add_noise( timesteps: torch.IntTensor, ) -> torch.FloatTensor: # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - alphas_cumprod = self.alphas_cumprod.to( - device=original_samples.device, dtype=original_samples.dtype - ) + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) timesteps = timesteps.to(original_samples.device) sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 @@ -1215,9 +1203,7 @@ def add_noise( while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - noisy_samples = ( - sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise - ) + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples def __len__(self):