diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index d6252904fd9a..df84dd6fd65d 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -103,6 +103,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): interpolation_type (`str`, default `"linear"`, optional): interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of [`"linear"`, `"log_linear"`]. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the + noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence + of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -118,6 +122,7 @@ def __init__( trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", interpolation_type: str = "linear", + use_karras_sigmas: Optional[bool] = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -149,6 +154,7 @@ def __init__( timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() self.timesteps = torch.from_numpy(timesteps) self.is_scale_input_called = False + self.use_karras_sigmas = use_karras_sigmas def scale_model_input( self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] @@ -187,6 +193,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) if self.config.interpolation_type == "linear": sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) @@ -198,6 +205,10 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic " 'linear' or 'log_linear'" ) + if self.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas).to(device=device) if str(device).startswith("mps"): @@ -206,6 +217,43 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic else: self.timesteps = torch.from_numpy(timesteps).to(device=device) + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(sigma) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17 + def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, self.num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + def step( self, model_output: torch.FloatTensor, diff --git a/tests/schedulers/test_scheduler_euler.py b/tests/schedulers/test_scheduler_euler.py index 4d521b0075e1..aa46ef31885a 100644 --- a/tests/schedulers/test_scheduler_euler.py +++ b/tests/schedulers/test_scheduler_euler.py @@ -117,3 +117,30 @@ def test_full_loop_device(self): assert abs(result_sum.item() - 10.0807) < 1e-2 assert abs(result_mean.item() - 0.0131) < 1e-3 + + def test_full_loop_device_karras_sigmas(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config, use_karras_sigmas=True) + + scheduler.set_timesteps(self.num_inference_steps, device=torch_device) + + generator = torch.manual_seed(0) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = sample.to(torch_device) + + for t in scheduler.timesteps: + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample, generator=generator) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 124.52299499511719) < 1e-2 + assert abs(result_mean.item() - 0.16213932633399963) < 1e-3