From f2d1ec61ba9a9aee1d0363ee8f42cc6a45bf28c4 Mon Sep 17 00:00:00 2001 From: MaxWe00 Date: Sat, 3 Jun 2023 18:06:20 +0200 Subject: [PATCH 01/14] Implement option for rescaling betas to zero terminal SNR --- src/diffusers/schedulers/scheduling_ddim.py | 44 ++++++++++++++++++ .../schedulers/scheduling_ddim_inverse.py | 44 ++++++++++++++++++ src/diffusers/schedulers/scheduling_ddpm.py | 44 ++++++++++++++++++ .../schedulers/scheduling_deis_multistep.py | 42 ++++++++++++++++- .../scheduling_dpmsolver_multistep.py | 44 ++++++++++++++++++ .../scheduling_dpmsolver_multistep_inverse.py | 44 ++++++++++++++++++ .../schedulers/scheduling_dpmsolver_sde.py | 43 ++++++++++++++++++ .../scheduling_dpmsolver_singlestep.py | 41 ++++++++++++++++- .../scheduling_euler_ancestral_discrete.py | 44 +++++++++++++++++- .../schedulers/scheduling_euler_discrete.py | 43 ++++++++++++++++++ .../schedulers/scheduling_heun_discrete.py | 43 ++++++++++++++++++ .../scheduling_k_dpm_2_ancestral_discrete.py | 43 ++++++++++++++++++ .../schedulers/scheduling_k_dpm_2_discrete.py | 44 ++++++++++++++++++ .../schedulers/scheduling_lms_discrete.py | 43 ++++++++++++++++++ src/diffusers/schedulers/scheduling_pndm.py | 45 ++++++++++++++++++- .../schedulers/scheduling_repaint.py | 44 +++++++++++++++++- src/diffusers/schedulers/scheduling_unclip.py | 43 ++++++++++++++++++ .../schedulers/scheduling_unipc_multistep.py | 40 +++++++++++++++++ 18 files changed, 773 insertions(+), 5 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 6b62d8893482..76443449dea5 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -76,6 +76,41 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) +def rescale_zero_terminal_snr(alphas_cumprod): + """ + Rescales betas to have zero terminal SNR (signal-to-noise-ratio) Based on https://arxiv.org/pdf/2305.08891.pdf + (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + class DDIMScheduler(SchedulerMixin, ConfigMixin): """ Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising @@ -122,6 +157,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`. sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True`. + rescale_betas_zero_snr (`bool`, default `False`): + whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). + This can enable the model to generate very bright and dark samples instead of limiting it to samples with + medium brightness. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -143,6 +182,7 @@ def __init__( dynamic_thresholding_ratio: float = 0.995, clip_sample_range: float = 1.0, sample_max_value: float = 1.0, + rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -168,6 +208,10 @@ def __init__( # whether we use the final alpha of the "non-previous" one. self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.alphas_cumprod) + # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 diff --git a/src/diffusers/schedulers/scheduling_ddim_inverse.py b/src/diffusers/schedulers/scheduling_ddim_inverse.py index 2c9fc036a027..338448ebb1ad 100644 --- a/src/diffusers/schedulers/scheduling_ddim_inverse.py +++ b/src/diffusers/schedulers/scheduling_ddim_inverse.py @@ -75,6 +75,41 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) +def rescale_zero_terminal_snr(alphas_cumprod): + """ + Rescales betas to have zero terminal SNR (signal-to-noise-ratio) Based on https://arxiv.org/pdf/2305.08891.pdf + (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): """ DDIMInverseScheduler is the reverse scheduler of [`DDIMScheduler`]. @@ -111,6 +146,10 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) + rescale_betas_zero_snr (`bool`, default `False`): + whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). + This can enable the model to generate very bright and dark samples instead of limiting it to samples with + medium brightness. """ order = 1 @@ -128,6 +167,7 @@ def __init__( steps_offset: int = 0, prediction_type: str = "epsilon", clip_sample_range: float = 1.0, + rescale_betas_zero_snr: bool = False, **kwargs, ): if kwargs.get("set_alpha_to_one", None) is not None: @@ -161,6 +201,10 @@ def __init__( # or whether we use the final alpha of the "non-previous" one. 
self.final_alpha_cumprod = torch.tensor(0.0) if set_alpha_to_zero else self.alphas_cumprod[-1] + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.alphas_cumprod) + # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 5d24766d68c7..412483386c2e 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -73,6 +73,41 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) +def rescale_zero_terminal_snr(alphas_cumprod): + """ + Rescales betas to have zero terminal SNR (signal-to-noise-ratio) Based on https://arxiv.org/pdf/2305.08891.pdf + (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + class DDPMScheduler(SchedulerMixin, ConfigMixin): """ Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and @@ -114,6 +149,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`. sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True`. + rescale_betas_zero_snr (`bool`, default `False`): + whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). + This can enable the model to generate very bright and dark samples instead of limiting it to samples with + medium brightness. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -134,6 +173,7 @@ def __init__( dynamic_thresholding_ratio: float = 0.995, clip_sample_range: float = 1.0, sample_max_value: float = 1.0, + rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -158,6 +198,10 @@ def __init__( self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) self.one = torch.tensor(1.0) + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.alphas_cumprod) + # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index 8ea001a882d0..a37c2fe5a1e5 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -55,6 +55,38 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) +def rescale_zero_terminal_snr(alphas_bar_sqrt): + """ + Rescales betas to have zero terminal SNR (signal-to-noise-ratio) Based on https://arxiv.org/pdf/2305.08891.pdf + (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): """ DEIS (https://arxiv.org/abs/2204.13902) is a fast high order solver for diffusion ODEs. We slightly modify the @@ -103,7 +135,10 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): lower_order_final (`bool`, default `True`): whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically find this trick can stabilize the sampling of DEIS for steps < 15, especially for steps <= 10. - + rescale_betas_zero_snr (`bool`, default `False`): + whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). + This can enable the model to generate very bright and dark samples instead of limiting it to samples with + medium brightness. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -125,6 +160,7 @@ def __init__( algorithm_type: str = "deis", solver_type: str = "logrho", lower_order_final: bool = True, + rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -148,6 +184,10 @@ def __init__( self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.alpha_t) + # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index e72b1bdc23b5..60a7221fc15e 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -55,6 +55,41 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) +def rescale_zero_terminal_snr(alphas_cumprod): + """ + Rescales betas to have zero terminal SNR (signal-to-noise-ratio) Based on https://arxiv.org/pdf/2305.08891.pdf + (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): """ DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with @@ -134,6 +169,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on diffusion ODEs. + rescale_betas_zero_snr (`bool`, default `False`): + whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). + This can enable the model to generate very bright and dark samples instead of limiting it to samples with + medium brightness. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -158,6 +197,7 @@ def __init__( use_karras_sigmas: Optional[bool] = False, lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, + rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -181,6 +221,10 @@ def __init__( self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.alphas_cumprod) + # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index b424ebbff262..d263ed540a22 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -55,6 +55,41 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) +def rescale_zero_terminal_snr(alphas_cumprod): + """ + Rescales betas to have zero terminal SNR (signal-to-noise-ratio) Based on https://arxiv.org/pdf/2305.08891.pdf + (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): """ DPMSolverMultistepInverseScheduler is the reverse scheduler of [`DPMSolverMultistepScheduler`]. @@ -123,6 +158,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on diffusion ODEs. + rescale_betas_zero_snr (`bool`, default `False`): + whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). + This can enable the model to generate very bright and dark samples instead of limiting it to samples with + medium brightness. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -147,6 +186,7 @@ def __init__( use_karras_sigmas: Optional[bool] = False, lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, + rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -173,6 +213,10 @@ def __init__( # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.alphas_cumprod) + # settings for DPM-Solver if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: if algorithm_type == "deis": diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index ae9229981152..76f63a4749ae 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -105,6 +105,40 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) +def rescale_zero_terminal_snr(alphas_cumprod): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): """ Implements Stochastic Sampler (Algorithm 2) from Karras et al. (2022). Based on the original k-diffusion @@ -133,6 +167,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. noise_sampler_seed (`int`, *optional*, defaults to `None`): The random seed to use for the noise sampler. If `None`, a random seed will be generated. + rescale_betas_zero_snr (`bool`, default `False`): + whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). + This can enable the model to generate very bright and dark samples instead of limiting it to samples with + medium brightness. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -149,6 +187,7 @@ def __init__( prediction_type: str = "epsilon", use_karras_sigmas: Optional[bool] = False, noise_sampler_seed: Optional[int] = None, + rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -168,6 +207,10 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.alphas_cumprod) + # set all values self.set_timesteps(num_train_timesteps, None, num_train_timesteps) self.use_karras_sigmas = use_karras_sigmas diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 7fa8eabb5a15..327809a3c12e 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -58,6 +58,37 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) +def rescale_zero_terminal_snr(alphas_bar_sqrt): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): """ DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with @@ -132,7 +163,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on diffusion ODEs. - + rescale_betas_zero_snr (`bool`, default `False`): + whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). + This can enable the model to generate very bright and dark samples instead of limiting it to samples with + medium brightness. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -157,6 +191,7 @@ def __init__( use_karras_sigmas: Optional[bool] = False, lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, + rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -180,6 +215,10 @@ def __init__( self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.alpha_t) + # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 6b08e9bfc207..70130791d23a 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -76,6 +76,40 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) +def rescale_zero_terminal_snr(alphas_cumprod): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): """ Ancestral sampling with Euler method steps. Based on the original k-diffusion implementation by Katherine Crowson: @@ -99,7 +133,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) - + rescale_betas_zero_snr (`bool`, default `False`): + whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). + This can enable the model to generate very bright and dark samples instead of limiting it to samples with + medium brightness. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -114,6 +151,7 @@ def __init__( beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", + rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -133,6 +171,10 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.alphas_cumprod) + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 7237128cbf07..aaa1682e4767 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -76,6 +76,40 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) +def rescale_zero_terminal_snr(alphas_cumprod): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): """ Euler scheduler (Algorithm 2) from Karras et al. (2022) https://arxiv.org/abs/2206.00364. . Based on the original @@ -107,6 +141,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + rescale_betas_zero_snr (`bool`, default `False`): + whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). + This can enable the model to generate very bright and dark samples instead of limiting it to samples with + medium brightness. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -123,6 +161,7 @@ def __init__( prediction_type: str = "epsilon", interpolation_type: str = "linear", use_karras_sigmas: Optional[bool] = False, + rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -142,6 +181,10 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.alphas_cumprod) + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index 100e2012ea20..d8b10c6a09cc 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -52,6 +52,40 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) +def rescale_zero_terminal_snr(alphas_cumprod): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): """ Implements Algorithm 2 (Heun steps) from Karras et al. (2022). for discrete beta schedules. Based on the original @@ -78,6 +112,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + rescale_betas_zero_snr (`bool`, default `False`): + whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). + This can enable the model to generate very bright and dark samples instead of limiting it to samples with + medium brightness. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -93,6 +131,7 @@ def __init__( trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", use_karras_sigmas: Optional[bool] = False, + rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -112,6 +151,10 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.alphas_cumprod) + # set all values self.set_timesteps(num_train_timesteps, None, num_train_timesteps) self.use_karras_sigmas = use_karras_sigmas diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index 2fa0431e1292..c2c1b70f85e5 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -53,6 +53,40 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) +def rescale_zero_terminal_snr(alphas_cumprod): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): """ Scheduler created by @crowsonkb in [k_diffusion](https://github.com/crowsonkb/k-diffusion), see: @@ -78,6 +112,10 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) + rescale_betas_zero_snr (`bool`, default `False`): + whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). + This can enable the model to generate very bright and dark samples instead of limiting it to samples with + medium brightness. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -92,6 +130,7 @@ def __init__( beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", + rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -111,6 +150,10 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.alphas_cumprod) + # set all values self.set_timesteps(num_train_timesteps, None, num_train_timesteps) diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py index bb80c4a54bfe..3f7d7b9535e7 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -52,6 +52,41 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) +def rescale_zero_terminal_snr(alphas_cumprod): + """ + Rescales betas to have zero terminal SNR (signal-to-noise-ratio) Based on https://arxiv.org/pdf/2305.08891.pdf + (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): """ Scheduler created by @crowsonkb in [k_diffusion](https://github.com/crowsonkb/k-diffusion), see: @@ -77,6 +112,10 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) + rescale_betas_zero_snr (`bool`, default `False`): + whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). + This can enable the model to generate very bright and dark samples instead of limiting it to samples with + medium brightness. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -91,6 +130,7 @@ def __init__( beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", + rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -110,6 +150,10 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.alphas_cumprod) + # set all values self.set_timesteps(num_train_timesteps, None, num_train_timesteps) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 0656475c3093..3d0a3adbe013 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -74,6 +74,40 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) +def rescale_zero_terminal_snr(alphas_cumprod): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): """ Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by @@ -102,6 +136,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) + rescale_betas_zero_snr (`bool`, default `False`): + whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). + This can enable the model to generate very bright and dark samples instead of limiting it to samples with + medium brightness. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -117,6 +155,7 @@ def __init__( trained_betas: Optional[Union[np.ndarray, List[float]]] = None, use_karras_sigmas: Optional[bool] = False, prediction_type: str = "epsilon", + rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -136,6 +175,10 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.alphas_cumprod) + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 01c02a21bbfc..9a32320b7385 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -54,6 +54,41 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) +def rescale_zero_terminal_snr(alphas_cumprod): + """ + Rescales betas to have zero terminal SNR (signal-to-noise-ratio) Based on https://arxiv.org/pdf/2305.08891.pdf + (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + class PNDMScheduler(SchedulerMixin, ConfigMixin): """ Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, @@ -89,7 +124,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. - + rescale_betas_zero_snr (`bool`, default `False`): + whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). + This can enable the model to generate very bright and dark samples instead of limiting it to samples with + medium brightness. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -107,6 +145,7 @@ def __init__( set_alpha_to_one: bool = False, prediction_type: str = "epsilon", steps_offset: int = 0, + rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -128,6 +167,10 @@ def __init__( self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.alphas_cumprod) + # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 diff --git a/src/diffusers/schedulers/scheduling_repaint.py b/src/diffusers/schedulers/scheduling_repaint.py index f2f97b38f3d3..c262554c7c40 100644 --- a/src/diffusers/schedulers/scheduling_repaint.py +++ b/src/diffusers/schedulers/scheduling_repaint.py @@ -72,6 +72,40 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) +def rescale_zero_terminal_snr(alphas_cumprod): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + class RePaintScheduler(SchedulerMixin, ConfigMixin): """ RePaint is a schedule for DDPM inpainting inside a given mask. @@ -100,7 +134,10 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin): `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. clip_sample (`bool`, default `True`): option to clip predicted sample between -1 and 1 for numerical stability. - + rescale_betas_zero_snr (`bool`, default `False`): + whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). + This can enable the model to generate very bright and dark samples instead of limiting it to samples with + medium brightness. 
""" order = 1 @@ -115,6 +152,7 @@ def __init__( eta: float = 0.0, trained_betas: Optional[np.ndarray] = None, clip_sample: bool = True, + rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) @@ -141,6 +179,10 @@ def __init__( self.final_alpha_cumprod = torch.tensor(1.0) + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.alphas_cumprod) + # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 diff --git a/src/diffusers/schedulers/scheduling_unclip.py b/src/diffusers/schedulers/scheduling_unclip.py index d44edcb1812a..43c9cd0668bc 100644 --- a/src/diffusers/schedulers/scheduling_unclip.py +++ b/src/diffusers/schedulers/scheduling_unclip.py @@ -73,6 +73,40 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) +def rescale_zero_terminal_snr(alphas_cumprod): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + class UnCLIPScheduler(SchedulerMixin, ConfigMixin): """ NOTE: do not use this scheduler. The DDPM scheduler has been updated to support the changes made here. This @@ -100,6 +134,10 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin): prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process) or `sample` (directly predicting the noisy sample`) + rescale_betas_zero_snr (`bool`, default `False`): + whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). + This can enable the model to generate very bright and dark samples instead of limiting it to samples with + medium brightness. 
""" @register_to_config @@ -111,6 +149,7 @@ def __init__( clip_sample_range: Optional[float] = 1.0, prediction_type: str = "epsilon", beta_schedule: str = "squaredcos_cap_v2", + rescale_betas_zero_snr: bool = False, ): if beta_schedule != "squaredcos_cap_v2": raise ValueError("UnCLIPScheduler only supports `beta_schedule`: 'squaredcos_cap_v2'") @@ -121,6 +160,10 @@ def __init__( self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) self.one = torch.tensor(1.0) + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.alphas_cumprod) + # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 2cce68f7d962..6ed33333a0e3 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -54,6 +54,37 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) +def rescale_zero_terminal_snr(alphas_bar_sqrt): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): """ UniPC is a training-free framework designed for the fast sampling of diffusion models, which consists of a @@ -117,6 +148,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): by disable the corrector at the first few steps (e.g., disable_corrector=[0]) solver_p (`SchedulerMixin`, default `None`): can be any other scheduler. If specified, the algorithm will become solver_p + UniC. + rescale_betas_zero_snr (`bool`, default `False`): + whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). + This can enable the model to generate very bright and dark samples instead of limiting it to samples with + medium brightness. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -140,6 +175,7 @@ def __init__( lower_order_final: bool = True, disable_corrector: List[int] = [], solver_p: SchedulerMixin = None, + rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -163,6 +199,10 @@ def __init__( self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.alpha_t) + # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 From e70eac23eda4086ee269979eb299969a01b046d6 Mon Sep 17 00:00:00 2001 From: MaxWe00 Date: Sat, 3 Jun 2023 18:48:40 +0200 Subject: [PATCH 02/14] Implement rescale classifier free guidance in pipeline_stable_diffusion.py --- .../pipeline_stable_diffusion.py | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 170002b2514e..b8c5fc0027f4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -560,6 +560,7 @@ def __call__( callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.7, ): r""" Function invoked when calling the pipeline for generation. @@ -620,6 +621,11 @@ def __call__( A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. Examples: @@ -706,8 +712,20 @@ def __call__( noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + std_text = torch.std(noise_pred_text) + std_pred = torch.std(noise_pred) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_pred * (std_text / std_pred) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_pred_rescaled_final = ( + guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_pred + ) + # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = self.scheduler.step( + noise_pred_rescaled_final, t, latents, **extra_step_kwargs, return_dict=False + )[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): From fe3cb4261b0104513e69ee143fd7d39614769898 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 5 Jun 2023 12:52:56 +0000 Subject: [PATCH 03/14] focus on DDIM --- .../pipeline_stable_diffusion.py | 32 ++++++++----- src/diffusers/schedulers/scheduling_ddim.py | 27 +++++++---- .../schedulers/scheduling_ddim_inverse.py | 44 ------------------ src/diffusers/schedulers/scheduling_ddpm.py | 44 ------------------ .../schedulers/scheduling_deis_multistep.py | 42 +---------------- .../scheduling_dpmsolver_multistep.py | 44 ------------------ .../scheduling_dpmsolver_multistep_inverse.py | 44 ------------------ .../schedulers/scheduling_dpmsolver_sde.py | 43 ------------------ .../scheduling_dpmsolver_singlestep.py | 41 +---------------- .../scheduling_euler_ancestral_discrete.py | 44 +----------------- .../schedulers/scheduling_euler_discrete.py | 43 ------------------ .../schedulers/scheduling_heun_discrete.py | 43 ------------------ .../scheduling_k_dpm_2_ancestral_discrete.py | 43 ------------------ .../schedulers/scheduling_k_dpm_2_discrete.py | 44 ------------------ .../schedulers/scheduling_lms_discrete.py | 43 ------------------ src/diffusers/schedulers/scheduling_pndm.py | 45 +------------------ .../schedulers/scheduling_repaint.py | 44 +----------------- src/diffusers/schedulers/scheduling_unclip.py | 43 ------------------ .../schedulers/scheduling_unipc_multistep.py | 40 ----------------- 19 files changed, 45 insertions(+), 748 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index b8c5fc0027f4..35d7adf8c58d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -54,6 +54,24 @@ ``` """ +def rescale_noise_pred(noise_pred, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_pred` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
+ See Section 3.4 + """ + # std_text = torch.std(noise_pred_text) + # std_pred = torch.std(noise_pred) + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_pred = noise_pred.std(dim=list(range(1, noise_pred.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_pred * (std_text / std_pred) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_pred = ( + guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_pred + ) + return noise_pred + + class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin): r""" @@ -560,7 +578,7 @@ def __call__( callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, - guidance_rescale: float = 0.7, + guidance_rescale: float = 0.0, ): r""" Function invoked when calling the pipeline for generation. @@ -712,19 +730,13 @@ def __call__( noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + if do_classifier_free_guidance and guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - std_text = torch.std(noise_pred_text) - std_pred = torch.std(noise_pred) - # rescale the results from guidance (fixes overexposure) - noise_pred_rescaled = noise_pred * (std_text / std_pred) - # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images - noise_pred_rescaled_final = ( - guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_pred - ) + noise_pred = rescale_noise_pred(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step( - noise_pred_rescaled_final, t, latents, **extra_step_kwargs, return_dict=False + noise_pred, t, latents, **extra_step_kwargs, return_dict=False )[0] # call the callback, if provided diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 76443449dea5..628d5f8531c5 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -76,10 +76,9 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) -def rescale_zero_terminal_snr(alphas_cumprod): +def rescale_zero_terminal_snr(betas): """ - Rescales betas to have zero terminal SNR (signal-to-noise-ratio) Based on https://arxiv.org/pdf/2305.08891.pdf - (Algorithm 1) + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) Args: @@ -90,8 +89,11 @@ def rescale_zero_terminal_snr(alphas_cumprod): `torch.FloatTensor`: rescaled betas with zero terminal SNR """ # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) alphas_bar_sqrt = alphas_cumprod.sqrt() + # Store old values. 
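    # (Illustrative annotation, not from the original patch: the shift below forces alphas_bar_sqrt[-1]
    # to exactly 0, so the terminal SNR alphas_bar[-1] / (1 - alphas_bar[-1]) becomes 0, while the
    # subsequent scale maps alphas_bar_sqrt[0] back onto its original value.)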
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() @@ -182,6 +184,7 @@ def __init__( dynamic_thresholding_ratio: float = 0.995, clip_sample_range: float = 1.0, sample_max_value: float = 1.0, + timestep_type: str = "leading", rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: @@ -199,6 +202,10 @@ def __init__( else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) @@ -208,10 +215,6 @@ def __init__( # whether we use the final alpha of the "non-previous" one. self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] - # Rescale for zero SNR - if rescale_betas_zero_snr: - self.betas = rescale_zero_terminal_snr(self.alphas_cumprod) - # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 @@ -298,9 +301,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + # + if self.config.timestep_type == "leading": + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_type == "trailing": + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -self.config.num_train_timesteps/num_inference_steps)).astype(np.int64).copy() + timesteps -= 1 + self.timesteps = torch.from_numpy(timesteps).to(device) - self.timesteps += self.config.steps_offset def step( self, diff --git a/src/diffusers/schedulers/scheduling_ddim_inverse.py b/src/diffusers/schedulers/scheduling_ddim_inverse.py index 338448ebb1ad..2c9fc036a027 100644 --- a/src/diffusers/schedulers/scheduling_ddim_inverse.py +++ b/src/diffusers/schedulers/scheduling_ddim_inverse.py @@ -75,41 +75,6 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) -def rescale_zero_terminal_snr(alphas_cumprod): - """ - Rescales betas to have zero terminal SNR (signal-to-noise-ratio) Based on https://arxiv.org/pdf/2305.08891.pdf - (Algorithm 1) - - - Args: - betas (`torch.FloatTensor`): - the betas that the scheduler is being initialized with. - - Returns: - `torch.FloatTensor`: rescaled betas with zero terminal SNR - """ - # Convert betas to alphas_bar_sqrt - alphas_bar_sqrt = alphas_cumprod.sqrt() - - # Store old values. - alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - - # Shift so the last timestep is zero. - alphas_bar_sqrt -= alphas_bar_sqrt_T - - # Scale so the first timestep is back to the old value. - alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) - - # Convert alphas_bar_sqrt to betas - alphas_bar = alphas_bar_sqrt**2 # Revert sqrt - alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod - alphas = torch.cat([alphas_bar[0:1], alphas]) - betas = 1 - alphas - - return betas - - class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): """ DDIMInverseScheduler is the reverse scheduler of [`DDIMScheduler`]. 
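For intuition, here is a minimal standalone sketch of the two spacings introduced by the `set_timesteps` change above. The values (num_train_timesteps=1000, num_inference_steps=10, steps_offset=0) are assumed for illustration and are not taken from the patch; the key property of "trailing" is that sampling starts at the terminal training timestep 999, the step that `rescale_zero_terminal_snr` drives to zero SNR.

import numpy as np

num_train_timesteps, num_inference_steps = 1000, 10
step_ratio = num_train_timesteps // num_inference_steps

# "leading" spacing: starts at timestep 0 and never reaches the terminal timestep 999
leading = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64)
print(leading)   # [900 800 700 600 500 400 300 200 100   0]

# "trailing" spacing: starts at the terminal (zero-SNR) timestep 999
trailing = np.round(np.arange(num_train_timesteps, 0, -num_train_timesteps / num_inference_steps))
trailing = trailing.astype(np.int64) - 1
print(trailing)  # [999 899 799 699 599 499 399 299 199  99]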
@@ -146,10 +111,6 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) - rescale_betas_zero_snr (`bool`, default `False`): - whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). - This can enable the model to generate very bright and dark samples instead of limiting it to samples with - medium brightness. """ order = 1 @@ -167,7 +128,6 @@ def __init__( steps_offset: int = 0, prediction_type: str = "epsilon", clip_sample_range: float = 1.0, - rescale_betas_zero_snr: bool = False, **kwargs, ): if kwargs.get("set_alpha_to_one", None) is not None: @@ -201,10 +161,6 @@ def __init__( # or whether we use the final alpha of the "non-previous" one. self.final_alpha_cumprod = torch.tensor(0.0) if set_alpha_to_zero else self.alphas_cumprod[-1] - # Rescale for zero SNR - if rescale_betas_zero_snr: - self.betas = rescale_zero_terminal_snr(self.alphas_cumprod) - # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 412483386c2e..5d24766d68c7 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -73,41 +73,6 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) -def rescale_zero_terminal_snr(alphas_cumprod): - """ - Rescales betas to have zero terminal SNR (signal-to-noise-ratio) Based on https://arxiv.org/pdf/2305.08891.pdf - (Algorithm 1) - - - Args: - betas (`torch.FloatTensor`): - the betas that the scheduler is being initialized with. - - Returns: - `torch.FloatTensor`: rescaled betas with zero terminal SNR - """ - # Convert betas to alphas_bar_sqrt - alphas_bar_sqrt = alphas_cumprod.sqrt() - - # Store old values. - alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - - # Shift so the last timestep is zero. - alphas_bar_sqrt -= alphas_bar_sqrt_T - - # Scale so the first timestep is back to the old value. - alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) - - # Convert alphas_bar_sqrt to betas - alphas_bar = alphas_bar_sqrt**2 # Revert sqrt - alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod - alphas = torch.cat([alphas_bar[0:1], alphas]) - betas = 1 - alphas - - return betas - - class DDPMScheduler(SchedulerMixin, ConfigMixin): """ Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and @@ -149,10 +114,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`. sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True`. - rescale_betas_zero_snr (`bool`, default `False`): - whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). - This can enable the model to generate very bright and dark samples instead of limiting it to samples with - medium brightness. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -173,7 +134,6 @@ def __init__( dynamic_thresholding_ratio: float = 0.995, clip_sample_range: float = 1.0, sample_max_value: float = 1.0, - rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -198,10 +158,6 @@ def __init__( self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) self.one = torch.tensor(1.0) - # Rescale for zero SNR - if rescale_betas_zero_snr: - self.betas = rescale_zero_terminal_snr(self.alphas_cumprod) - # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index a37c2fe5a1e5..8ea001a882d0 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -55,38 +55,6 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) -def rescale_zero_terminal_snr(alphas_bar_sqrt): - """ - Rescales betas to have zero terminal SNR (signal-to-noise-ratio) Based on https://arxiv.org/pdf/2305.08891.pdf - (Algorithm 1) - - - Args: - betas (`torch.FloatTensor`): - the betas that the scheduler is being initialized with. - - Returns: - `torch.FloatTensor`: rescaled betas with zero terminal SNR - """ - # Store old values. - alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - - # Shift so the last timestep is zero. - alphas_bar_sqrt -= alphas_bar_sqrt_T - - # Scale so the first timestep is back to the old value. - alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) - - # Convert alphas_bar_sqrt to betas - alphas_bar = alphas_bar_sqrt**2 # Revert sqrt - alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod - alphas = torch.cat([alphas_bar[0:1], alphas]) - betas = 1 - alphas - - return betas - - class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): """ DEIS (https://arxiv.org/abs/2204.13902) is a fast high order solver for diffusion ODEs. We slightly modify the @@ -135,10 +103,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): lower_order_final (`bool`, default `True`): whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically find this trick can stabilize the sampling of DEIS for steps < 15, especially for steps <= 10. - rescale_betas_zero_snr (`bool`, default `False`): - whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). - This can enable the model to generate very bright and dark samples instead of limiting it to samples with - medium brightness. 
+ """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -160,7 +125,6 @@ def __init__( algorithm_type: str = "deis", solver_type: str = "logrho", lower_order_final: bool = True, - rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -184,10 +148,6 @@ def __init__( self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) - # Rescale for zero SNR - if rescale_betas_zero_snr: - self.betas = rescale_zero_terminal_snr(self.alpha_t) - # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 60a7221fc15e..e72b1bdc23b5 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -55,41 +55,6 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) -def rescale_zero_terminal_snr(alphas_cumprod): - """ - Rescales betas to have zero terminal SNR (signal-to-noise-ratio) Based on https://arxiv.org/pdf/2305.08891.pdf - (Algorithm 1) - - - Args: - betas (`torch.FloatTensor`): - the betas that the scheduler is being initialized with. - - Returns: - `torch.FloatTensor`: rescaled betas with zero terminal SNR - """ - # Convert betas to alphas_bar_sqrt - alphas_bar_sqrt = alphas_cumprod.sqrt() - - # Store old values. - alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - - # Shift so the last timestep is zero. - alphas_bar_sqrt -= alphas_bar_sqrt_T - - # Scale so the first timestep is back to the old value. - alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) - - # Convert alphas_bar_sqrt to betas - alphas_bar = alphas_bar_sqrt**2 # Revert sqrt - alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod - alphas = torch.cat([alphas_bar[0:1], alphas]) - betas = 1 - alphas - - return betas - - class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): """ DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with @@ -169,10 +134,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on diffusion ODEs. - rescale_betas_zero_snr (`bool`, default `False`): - whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). - This can enable the model to generate very bright and dark samples instead of limiting it to samples with - medium brightness. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -197,7 +158,6 @@ def __init__( use_karras_sigmas: Optional[bool] = False, lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, - rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -221,10 +181,6 @@ def __init__( self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) - # Rescale for zero SNR - if rescale_betas_zero_snr: - self.betas = rescale_zero_terminal_snr(self.alphas_cumprod) - # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index d263ed540a22..b424ebbff262 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -55,41 +55,6 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) -def rescale_zero_terminal_snr(alphas_cumprod): - """ - Rescales betas to have zero terminal SNR (signal-to-noise-ratio) Based on https://arxiv.org/pdf/2305.08891.pdf - (Algorithm 1) - - - Args: - betas (`torch.FloatTensor`): - the betas that the scheduler is being initialized with. - - Returns: - `torch.FloatTensor`: rescaled betas with zero terminal SNR - """ - # Convert betas to alphas_bar_sqrt - alphas_bar_sqrt = alphas_cumprod.sqrt() - - # Store old values. - alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - - # Shift so the last timestep is zero. - alphas_bar_sqrt -= alphas_bar_sqrt_T - - # Scale so the first timestep is back to the old value. - alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) - - # Convert alphas_bar_sqrt to betas - alphas_bar = alphas_bar_sqrt**2 # Revert sqrt - alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod - alphas = torch.cat([alphas_bar[0:1], alphas]) - betas = 1 - alphas - - return betas - - class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): """ DPMSolverMultistepInverseScheduler is the reverse scheduler of [`DPMSolverMultistepScheduler`]. @@ -158,10 +123,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on diffusion ODEs. - rescale_betas_zero_snr (`bool`, default `False`): - whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). - This can enable the model to generate very bright and dark samples instead of limiting it to samples with - medium brightness. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -186,7 +147,6 @@ def __init__( use_karras_sigmas: Optional[bool] = False, lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, - rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -213,10 +173,6 @@ def __init__( # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 - # Rescale for zero SNR - if rescale_betas_zero_snr: - self.betas = rescale_zero_terminal_snr(self.alphas_cumprod) - # settings for DPM-Solver if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: if algorithm_type == "deis": diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index 76f63a4749ae..ae9229981152 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -105,40 +105,6 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) -def rescale_zero_terminal_snr(alphas_cumprod): - """ - Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) - - - Args: - betas (`torch.FloatTensor`): - the betas that the scheduler is being initialized with. - - Returns: - `torch.FloatTensor`: rescaled betas with zero terminal SNR - """ - # Convert betas to alphas_bar_sqrt - alphas_bar_sqrt = alphas_cumprod.sqrt() - - # Store old values. - alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - - # Shift so the last timestep is zero. - alphas_bar_sqrt -= alphas_bar_sqrt_T - - # Scale so the first timestep is back to the old value. - alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) - - # Convert alphas_bar_sqrt to betas - alphas_bar = alphas_bar_sqrt**2 # Revert sqrt - alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod - alphas = torch.cat([alphas_bar[0:1], alphas]) - betas = 1 - alphas - - return betas - - class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): """ Implements Stochastic Sampler (Algorithm 2) from Karras et al. (2022). Based on the original k-diffusion @@ -167,10 +133,6 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. noise_sampler_seed (`int`, *optional*, defaults to `None`): The random seed to use for the noise sampler. If `None`, a random seed will be generated. - rescale_betas_zero_snr (`bool`, default `False`): - whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). - This can enable the model to generate very bright and dark samples instead of limiting it to samples with - medium brightness. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -187,7 +149,6 @@ def __init__( prediction_type: str = "epsilon", use_karras_sigmas: Optional[bool] = False, noise_sampler_seed: Optional[int] = None, - rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -207,10 +168,6 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - # Rescale for zero SNR - if rescale_betas_zero_snr: - self.betas = rescale_zero_terminal_snr(self.alphas_cumprod) - # set all values self.set_timesteps(num_train_timesteps, None, num_train_timesteps) self.use_karras_sigmas = use_karras_sigmas diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 327809a3c12e..7fa8eabb5a15 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -58,37 +58,6 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) -def rescale_zero_terminal_snr(alphas_bar_sqrt): - """ - Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) - - - Args: - betas (`torch.FloatTensor`): - the betas that the scheduler is being initialized with. - - Returns: - `torch.FloatTensor`: rescaled betas with zero terminal SNR - """ - # Store old values. - alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - - # Shift so the last timestep is zero. - alphas_bar_sqrt -= alphas_bar_sqrt_T - - # Scale so the first timestep is back to the old value. - alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) - - # Convert alphas_bar_sqrt to betas - alphas_bar = alphas_bar_sqrt**2 # Revert sqrt - alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod - alphas = torch.cat([alphas_bar[0:1], alphas]) - betas = 1 - alphas - - return betas - - class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): """ DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with @@ -163,10 +132,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on diffusion ODEs. - rescale_betas_zero_snr (`bool`, default `False`): - whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). - This can enable the model to generate very bright and dark samples instead of limiting it to samples with - medium brightness. 
+ """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -191,7 +157,6 @@ def __init__( use_karras_sigmas: Optional[bool] = False, lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, - rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -215,10 +180,6 @@ def __init__( self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) - # Rescale for zero SNR - if rescale_betas_zero_snr: - self.betas = rescale_zero_terminal_snr(self.alpha_t) - # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 70130791d23a..6b08e9bfc207 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -76,40 +76,6 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) -def rescale_zero_terminal_snr(alphas_cumprod): - """ - Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) - - - Args: - betas (`torch.FloatTensor`): - the betas that the scheduler is being initialized with. - - Returns: - `torch.FloatTensor`: rescaled betas with zero terminal SNR - """ - # Convert betas to alphas_bar_sqrt - alphas_bar_sqrt = alphas_cumprod.sqrt() - - # Store old values. - alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - - # Shift so the last timestep is zero. - alphas_bar_sqrt -= alphas_bar_sqrt_T - - # Scale so the first timestep is back to the old value. - alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) - - # Convert alphas_bar_sqrt to betas - alphas_bar = alphas_bar_sqrt**2 # Revert sqrt - alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod - alphas = torch.cat([alphas_bar[0:1], alphas]) - betas = 1 - alphas - - return betas - - class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): """ Ancestral sampling with Euler method steps. Based on the original k-diffusion implementation by Katherine Crowson: @@ -133,10 +99,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) - rescale_betas_zero_snr (`bool`, default `False`): - whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). - This can enable the model to generate very bright and dark samples instead of limiting it to samples with - medium brightness. 
+ """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -151,7 +114,6 @@ def __init__( beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", - rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -171,10 +133,6 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - # Rescale for zero SNR - if rescale_betas_zero_snr: - self.betas = rescale_zero_terminal_snr(self.alphas_cumprod) - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index aaa1682e4767..7237128cbf07 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -76,40 +76,6 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) -def rescale_zero_terminal_snr(alphas_cumprod): - """ - Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) - - - Args: - betas (`torch.FloatTensor`): - the betas that the scheduler is being initialized with. - - Returns: - `torch.FloatTensor`: rescaled betas with zero terminal SNR - """ - # Convert betas to alphas_bar_sqrt - alphas_bar_sqrt = alphas_cumprod.sqrt() - - # Store old values. - alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - - # Shift so the last timestep is zero. - alphas_bar_sqrt -= alphas_bar_sqrt_T - - # Scale so the first timestep is back to the old value. - alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) - - # Convert alphas_bar_sqrt to betas - alphas_bar = alphas_bar_sqrt**2 # Revert sqrt - alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod - alphas = torch.cat([alphas_bar[0:1], alphas]) - betas = 1 - alphas - - return betas - - class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): """ Euler scheduler (Algorithm 2) from Karras et al. (2022) https://arxiv.org/abs/2206.00364. . Based on the original @@ -141,10 +107,6 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. - rescale_betas_zero_snr (`bool`, default `False`): - whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). - This can enable the model to generate very bright and dark samples instead of limiting it to samples with - medium brightness. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -161,7 +123,6 @@ def __init__( prediction_type: str = "epsilon", interpolation_type: str = "linear", use_karras_sigmas: Optional[bool] = False, - rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -181,10 +142,6 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - # Rescale for zero SNR - if rescale_betas_zero_snr: - self.betas = rescale_zero_terminal_snr(self.alphas_cumprod) - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index d8b10c6a09cc..100e2012ea20 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -52,40 +52,6 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) -def rescale_zero_terminal_snr(alphas_cumprod): - """ - Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) - - - Args: - betas (`torch.FloatTensor`): - the betas that the scheduler is being initialized with. - - Returns: - `torch.FloatTensor`: rescaled betas with zero terminal SNR - """ - # Convert betas to alphas_bar_sqrt - alphas_bar_sqrt = alphas_cumprod.sqrt() - - # Store old values. - alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - - # Shift so the last timestep is zero. - alphas_bar_sqrt -= alphas_bar_sqrt_T - - # Scale so the first timestep is back to the old value. - alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) - - # Convert alphas_bar_sqrt to betas - alphas_bar = alphas_bar_sqrt**2 # Revert sqrt - alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod - alphas = torch.cat([alphas_bar[0:1], alphas]) - betas = 1 - alphas - - return betas - - class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): """ Implements Algorithm 2 (Heun steps) from Karras et al. (2022). for discrete beta schedules. Based on the original @@ -112,10 +78,6 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. - rescale_betas_zero_snr (`bool`, default `False`): - whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). - This can enable the model to generate very bright and dark samples instead of limiting it to samples with - medium brightness. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -131,7 +93,6 @@ def __init__( trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", use_karras_sigmas: Optional[bool] = False, - rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -151,10 +112,6 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - # Rescale for zero SNR - if rescale_betas_zero_snr: - self.betas = rescale_zero_terminal_snr(self.alphas_cumprod) - # set all values self.set_timesteps(num_train_timesteps, None, num_train_timesteps) self.use_karras_sigmas = use_karras_sigmas diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index c2c1b70f85e5..2fa0431e1292 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -53,40 +53,6 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) -def rescale_zero_terminal_snr(alphas_cumprod): - """ - Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) - - - Args: - betas (`torch.FloatTensor`): - the betas that the scheduler is being initialized with. - - Returns: - `torch.FloatTensor`: rescaled betas with zero terminal SNR - """ - # Convert betas to alphas_bar_sqrt - alphas_bar_sqrt = alphas_cumprod.sqrt() - - # Store old values. - alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - - # Shift so the last timestep is zero. - alphas_bar_sqrt -= alphas_bar_sqrt_T - - # Scale so the first timestep is back to the old value. - alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) - - # Convert alphas_bar_sqrt to betas - alphas_bar = alphas_bar_sqrt**2 # Revert sqrt - alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod - alphas = torch.cat([alphas_bar[0:1], alphas]) - betas = 1 - alphas - - return betas - - class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): """ Scheduler created by @crowsonkb in [k_diffusion](https://github.com/crowsonkb/k-diffusion), see: @@ -112,10 +78,6 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) - rescale_betas_zero_snr (`bool`, default `False`): - whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). - This can enable the model to generate very bright and dark samples instead of limiting it to samples with - medium brightness. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -130,7 +92,6 @@ def __init__( beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", - rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -150,10 +111,6 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - # Rescale for zero SNR - if rescale_betas_zero_snr: - self.betas = rescale_zero_terminal_snr(self.alphas_cumprod) - # set all values self.set_timesteps(num_train_timesteps, None, num_train_timesteps) diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py index 3f7d7b9535e7..bb80c4a54bfe 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -52,41 +52,6 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) -def rescale_zero_terminal_snr(alphas_cumprod): - """ - Rescales betas to have zero terminal SNR (signal-to-noise-ratio) Based on https://arxiv.org/pdf/2305.08891.pdf - (Algorithm 1) - - - Args: - betas (`torch.FloatTensor`): - the betas that the scheduler is being initialized with. - - Returns: - `torch.FloatTensor`: rescaled betas with zero terminal SNR - """ - # Convert betas to alphas_bar_sqrt - alphas_bar_sqrt = alphas_cumprod.sqrt() - - # Store old values. - alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - - # Shift so the last timestep is zero. - alphas_bar_sqrt -= alphas_bar_sqrt_T - - # Scale so the first timestep is back to the old value. - alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) - - # Convert alphas_bar_sqrt to betas - alphas_bar = alphas_bar_sqrt**2 # Revert sqrt - alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod - alphas = torch.cat([alphas_bar[0:1], alphas]) - betas = 1 - alphas - - return betas - - class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): """ Scheduler created by @crowsonkb in [k_diffusion](https://github.com/crowsonkb/k-diffusion), see: @@ -112,10 +77,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) - rescale_betas_zero_snr (`bool`, default `False`): - whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). - This can enable the model to generate very bright and dark samples instead of limiting it to samples with - medium brightness. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -130,7 +91,6 @@ def __init__( beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", - rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -150,10 +110,6 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - # Rescale for zero SNR - if rescale_betas_zero_snr: - self.betas = rescale_zero_terminal_snr(self.alphas_cumprod) - # set all values self.set_timesteps(num_train_timesteps, None, num_train_timesteps) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 3d0a3adbe013..0656475c3093 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -74,40 +74,6 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) -def rescale_zero_terminal_snr(alphas_cumprod): - """ - Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) - - - Args: - betas (`torch.FloatTensor`): - the betas that the scheduler is being initialized with. - - Returns: - `torch.FloatTensor`: rescaled betas with zero terminal SNR - """ - # Convert betas to alphas_bar_sqrt - alphas_bar_sqrt = alphas_cumprod.sqrt() - - # Store old values. - alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - - # Shift so the last timestep is zero. - alphas_bar_sqrt -= alphas_bar_sqrt_T - - # Scale so the first timestep is back to the old value. - alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) - - # Convert alphas_bar_sqrt to betas - alphas_bar = alphas_bar_sqrt**2 # Revert sqrt - alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod - alphas = torch.cat([alphas_bar[0:1], alphas]) - betas = 1 - alphas - - return betas - - class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): """ Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by @@ -136,10 +102,6 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) - rescale_betas_zero_snr (`bool`, default `False`): - whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). - This can enable the model to generate very bright and dark samples instead of limiting it to samples with - medium brightness. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -155,7 +117,6 @@ def __init__( trained_betas: Optional[Union[np.ndarray, List[float]]] = None, use_karras_sigmas: Optional[bool] = False, prediction_type: str = "epsilon", - rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -175,10 +136,6 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - # Rescale for zero SNR - if rescale_betas_zero_snr: - self.betas = rescale_zero_terminal_snr(self.alphas_cumprod) - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 9a32320b7385..01c02a21bbfc 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -54,41 +54,6 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) -def rescale_zero_terminal_snr(alphas_cumprod): - """ - Rescales betas to have zero terminal SNR (signal-to-noise-ratio) Based on https://arxiv.org/pdf/2305.08891.pdf - (Algorithm 1) - - - Args: - betas (`torch.FloatTensor`): - the betas that the scheduler is being initialized with. - - Returns: - `torch.FloatTensor`: rescaled betas with zero terminal SNR - """ - # Convert betas to alphas_bar_sqrt - alphas_bar_sqrt = alphas_cumprod.sqrt() - - # Store old values. - alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - - # Shift so the last timestep is zero. - alphas_bar_sqrt -= alphas_bar_sqrt_T - - # Scale so the first timestep is back to the old value. - alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) - - # Convert alphas_bar_sqrt to betas - alphas_bar = alphas_bar_sqrt**2 # Revert sqrt - alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod - alphas = torch.cat([alphas_bar[0:1], alphas]) - betas = 1 - alphas - - return betas - - class PNDMScheduler(SchedulerMixin, ConfigMixin): """ Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, @@ -124,10 +89,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. - rescale_betas_zero_snr (`bool`, default `False`): - whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). - This can enable the model to generate very bright and dark samples instead of limiting it to samples with - medium brightness. 
+ """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -145,7 +107,6 @@ def __init__( set_alpha_to_one: bool = False, prediction_type: str = "epsilon", steps_offset: int = 0, - rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -167,10 +128,6 @@ def __init__( self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] - # Rescale for zero SNR - if rescale_betas_zero_snr: - self.betas = rescale_zero_terminal_snr(self.alphas_cumprod) - # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 diff --git a/src/diffusers/schedulers/scheduling_repaint.py b/src/diffusers/schedulers/scheduling_repaint.py index c262554c7c40..f2f97b38f3d3 100644 --- a/src/diffusers/schedulers/scheduling_repaint.py +++ b/src/diffusers/schedulers/scheduling_repaint.py @@ -72,40 +72,6 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) -def rescale_zero_terminal_snr(alphas_cumprod): - """ - Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) - - - Args: - betas (`torch.FloatTensor`): - the betas that the scheduler is being initialized with. - - Returns: - `torch.FloatTensor`: rescaled betas with zero terminal SNR - """ - # Convert betas to alphas_bar_sqrt - alphas_bar_sqrt = alphas_cumprod.sqrt() - - # Store old values. - alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - - # Shift so the last timestep is zero. - alphas_bar_sqrt -= alphas_bar_sqrt_T - - # Scale so the first timestep is back to the old value. - alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) - - # Convert alphas_bar_sqrt to betas - alphas_bar = alphas_bar_sqrt**2 # Revert sqrt - alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod - alphas = torch.cat([alphas_bar[0:1], alphas]) - betas = 1 - alphas - - return betas - - class RePaintScheduler(SchedulerMixin, ConfigMixin): """ RePaint is a schedule for DDPM inpainting inside a given mask. @@ -134,10 +100,7 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin): `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. clip_sample (`bool`, default `True`): option to clip predicted sample between -1 and 1 for numerical stability. - rescale_betas_zero_snr (`bool`, default `False`): - whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). - This can enable the model to generate very bright and dark samples instead of limiting it to samples with - medium brightness. 
+ """ order = 1 @@ -152,7 +115,6 @@ def __init__( eta: float = 0.0, trained_betas: Optional[np.ndarray] = None, clip_sample: bool = True, - rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) @@ -179,10 +141,6 @@ def __init__( self.final_alpha_cumprod = torch.tensor(1.0) - # Rescale for zero SNR - if rescale_betas_zero_snr: - self.betas = rescale_zero_terminal_snr(self.alphas_cumprod) - # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 diff --git a/src/diffusers/schedulers/scheduling_unclip.py b/src/diffusers/schedulers/scheduling_unclip.py index 43c9cd0668bc..d44edcb1812a 100644 --- a/src/diffusers/schedulers/scheduling_unclip.py +++ b/src/diffusers/schedulers/scheduling_unclip.py @@ -73,40 +73,6 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) -def rescale_zero_terminal_snr(alphas_cumprod): - """ - Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) - - - Args: - betas (`torch.FloatTensor`): - the betas that the scheduler is being initialized with. - - Returns: - `torch.FloatTensor`: rescaled betas with zero terminal SNR - """ - # Convert betas to alphas_bar_sqrt - alphas_bar_sqrt = alphas_cumprod.sqrt() - - # Store old values. - alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - - # Shift so the last timestep is zero. - alphas_bar_sqrt -= alphas_bar_sqrt_T - - # Scale so the first timestep is back to the old value. - alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) - - # Convert alphas_bar_sqrt to betas - alphas_bar = alphas_bar_sqrt**2 # Revert sqrt - alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod - alphas = torch.cat([alphas_bar[0:1], alphas]) - betas = 1 - alphas - - return betas - - class UnCLIPScheduler(SchedulerMixin, ConfigMixin): """ NOTE: do not use this scheduler. The DDPM scheduler has been updated to support the changes made here. This @@ -134,10 +100,6 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin): prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process) or `sample` (directly predicting the noisy sample`) - rescale_betas_zero_snr (`bool`, default `False`): - whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). - This can enable the model to generate very bright and dark samples instead of limiting it to samples with - medium brightness. 
""" @register_to_config @@ -149,7 +111,6 @@ def __init__( clip_sample_range: Optional[float] = 1.0, prediction_type: str = "epsilon", beta_schedule: str = "squaredcos_cap_v2", - rescale_betas_zero_snr: bool = False, ): if beta_schedule != "squaredcos_cap_v2": raise ValueError("UnCLIPScheduler only supports `beta_schedule`: 'squaredcos_cap_v2'") @@ -160,10 +121,6 @@ def __init__( self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) self.one = torch.tensor(1.0) - # Rescale for zero SNR - if rescale_betas_zero_snr: - self.betas = rescale_zero_terminal_snr(self.alphas_cumprod) - # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 6ed33333a0e3..2cce68f7d962 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -54,37 +54,6 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) -def rescale_zero_terminal_snr(alphas_bar_sqrt): - """ - Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) - - - Args: - betas (`torch.FloatTensor`): - the betas that the scheduler is being initialized with. - - Returns: - `torch.FloatTensor`: rescaled betas with zero terminal SNR - """ - # Store old values. - alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - - # Shift so the last timestep is zero. - alphas_bar_sqrt -= alphas_bar_sqrt_T - - # Scale so the first timestep is back to the old value. - alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) - - # Convert alphas_bar_sqrt to betas - alphas_bar = alphas_bar_sqrt**2 # Revert sqrt - alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod - alphas = torch.cat([alphas_bar[0:1], alphas]) - betas = 1 - alphas - - return betas - - class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): """ UniPC is a training-free framework designed for the fast sampling of diffusion models, which consists of a @@ -148,10 +117,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): by disable the corrector at the first few steps (e.g., disable_corrector=[0]) solver_p (`SchedulerMixin`, default `None`): can be any other scheduler. If specified, the algorithm will become solver_p + UniC. - rescale_betas_zero_snr (`bool`, default `False`): - whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). - This can enable the model to generate very bright and dark samples instead of limiting it to samples with - medium brightness. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -175,7 +140,6 @@ def __init__( lower_order_final: bool = True, disable_corrector: List[int] = [], solver_p: SchedulerMixin = None, - rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -199,10 +163,6 @@ def __init__( self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) - # Rescale for zero SNR - if rescale_betas_zero_snr: - self.betas = rescale_zero_terminal_snr(self.alpha_t) - # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 From 5670826fcb43e32a2e3c6c383aaf9b5493c8c8da Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 5 Jun 2023 12:53:51 +0000 Subject: [PATCH 04/14] make style --- .../stable_diffusion/pipeline_stable_diffusion.py | 14 +++++--------- src/diffusers/schedulers/scheduling_ddim.py | 11 +++++++++-- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 35d7adf8c58d..48e08b7f1853 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -54,10 +54,11 @@ ``` """ + def rescale_noise_pred(noise_pred, noise_pred_text, guidance_rescale=0.0): """ - Rescale `noise_pred` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). - See Section 3.4 + Rescale `noise_pred` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 """ # std_text = torch.std(noise_pred_text) # std_pred = torch.std(noise_pred) @@ -66,13 +67,10 @@ def rescale_noise_pred(noise_pred, noise_pred_text, guidance_rescale=0.0): # rescale the results from guidance (fixes overexposure) noise_pred_rescaled = noise_pred * (std_text / std_pred) # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images - noise_pred = ( - guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_pred - ) + noise_pred = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_pred return noise_pred - class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin): r""" Pipeline for text-to-image generation using Stable Diffusion. 
@@ -735,9 +733,7 @@ def __call__( noise_pred = rescale_noise_pred(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step( - noise_pred, t, latents, **extra_step_kwargs, return_dict=False - )[0] + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 628d5f8531c5..2a7a79ff6c0a 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -93,7 +93,6 @@ def rescale_zero_terminal_snr(betas): alphas_cumprod = torch.cumprod(alphas, dim=0) alphas_bar_sqrt = alphas_cumprod.sqrt() - # Store old values. alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() @@ -306,7 +305,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) timesteps += self.config.steps_offset elif self.config.timestep_type == "trailing": - timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -self.config.num_train_timesteps/num_inference_steps)).astype(np.int64).copy() + timesteps = ( + np.round( + np.arange( + self.config.num_train_timesteps, 0, -self.config.num_train_timesteps / num_inference_steps + ) + ) + .astype(np.int64) + .copy() + ) timesteps -= 1 self.timesteps = torch.from_numpy(timesteps).to(device) From db5ff8294232232e0b1aa4b55a600f5fc3865152 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 5 Jun 2023 12:56:57 +0000 Subject: [PATCH 05/14] make style --- src/diffusers/schedulers/scheduling_ddim.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 2a7a79ff6c0a..d520df63fd40 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -300,21 +300,18 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - + + # "leading" and "trailing" correspond to the annotations in Table 2 of https://arxiv.org/abs/2305.08891 if self.config.timestep_type == "leading": timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) timesteps += self.config.steps_offset elif self.config.timestep_type == "trailing": - timesteps = ( - np.round( - np.arange( - self.config.num_train_timesteps, 0, -self.config.num_train_timesteps / num_inference_steps - ) - ) - .astype(np.int64) - .copy() - ) + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64).copy() timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_type} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
+ ) self.timesteps = torch.from_numpy(timesteps).to(device) From 6c62ff0d09dbb1bca6b15959cf9426ebba64ca14 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 5 Jun 2023 13:00:08 +0000 Subject: [PATCH 06/14] make style --- src/diffusers/schedulers/scheduling_ddim.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index d520df63fd40..8bd80a5e1d42 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -158,6 +158,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`. sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_scaling (`str`, default `"leading"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. rescale_betas_zero_snr (`bool`, default `False`): whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). This can enable the model to generate very bright and dark samples instead of limiting it to samples with @@ -183,7 +186,7 @@ def __init__( dynamic_thresholding_ratio: float = 0.995, clip_sample_range: float = 1.0, sample_max_value: float = 1.0, - timestep_type: str = "leading", + timestep_scaling: str = "leading", rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: @@ -302,15 +305,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic # casting to int to avoid issues when num_inference_step is power of 3 # "leading" and "trailing" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891 - if self.config.timestep_type == "leading": + if self.config.timestep_scaling == "leading": timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) timesteps += self.config.steps_offset - elif self.config.timestep_type == "trailing": + elif self.config.timestep_scaling == "trailing": timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64).copy() timesteps -= 1 else: raise ValueError( - f"{self.config.timestep_type} is not supported. Please make sure to choose one of 'leading' or 'trailing'." + f"{self.config.timestep_scaling} is not supported. Please make sure to choose one of 'leading' or 'trailing'." ) self.timesteps = torch.from_numpy(timesteps).to(device) From e4aef4d7267ce80a86181cc3fed1e3360a8257cc Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 5 Jun 2023 13:03:01 +0000 Subject: [PATCH 07/14] make style --- .../alt_diffusion/pipeline_alt_diffusion.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 8507684cf9b4..89f2066400cb 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -51,6 +51,23 @@ """ +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_pred +def rescale_noise_pred(noise_pred, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_pred` according to `guidance_rescale`. 
Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ # std_text = torch.std(noise_pred_text)
+ # std_pred = torch.std(noise_pred)
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_pred = noise_pred.std(dim=list(range(1, noise_pred.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_pred * (std_text / std_pred)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_pred = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_pred
+ return noise_pred
+
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
 class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
 r"""
@@ -559,6 +576,7 @@ def __call__(
 callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
 callback_steps: int = 1,
 cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
 ):
 r"""
 Function invoked when calling the pipeline for generation.
@@ -619,6 +637,11 @@
 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
 `self.processor` in
 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf), where `guidance_rescale` is defined as `φ` in
+ equation 16 of the paper. Guidance rescale should fix overexposure when using zero
+ terminal SNR.

 Examples:

@@ -705,6 +728,10 @@ def __call__(
 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4.
in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_pred(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] From 9cda72cc45dae9a8e378cff591ecba8af9c78a62 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 6 Jun 2023 22:30:48 +0000 Subject: [PATCH 08/14] Apply suggestions from Peter Lin --- .../pipeline_stable_diffusion.py | 14 +++++++------- src/diffusers/schedulers/scheduling_ddim.py | 19 +++++++++++-------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 48e08b7f1853..201f146c55f5 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -55,20 +55,20 @@ """ -def rescale_noise_pred(noise_pred, noise_pred_text, guidance_rescale=0.0): +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): """ - Rescale `noise_pred` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 """ # std_text = torch.std(noise_pred_text) # std_pred = torch.std(noise_pred) std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) - std_pred = noise_pred.std(dim=list(range(1, noise_pred.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) # rescale the results from guidance (fixes overexposure) - noise_pred_rescaled = noise_pred * (std_text / std_pred) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images - noise_pred = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_pred - return noise_pred + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin): @@ -730,7 +730,7 @@ def __call__( if do_classifier_free_guidance and guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_pred(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 8bd80a5e1d42..889c8dcc68c7 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -158,7 +158,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`. sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True`. 
- timestep_scaling (`str`, default `"leading"`):
+ timestep_spacing (`str`, default `"leading"`):
 The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
 Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
 rescale_betas_zero_snr (`bool`, default `False`):
 whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf).
 This can enable the model to generate very bright and dark samples instead of limiting it to samples with
 medium brightness.
@@ -186,7 +186,7 @@ def __init__(
 dynamic_thresholding_ratio: float = 0.995,
 clip_sample_range: float = 1.0,
 sample_max_value: float = 1.0,
- timestep_scaling: str = "leading",
+ timestep_spacing: str = "leading",
 rescale_betas_zero_snr: bool = False,
 ):
 if trained_betas is not None:
@@ -300,20 +300,23 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
 )
 self.num_inference_steps = num_inference_steps
- step_ratio = self.config.num_train_timesteps // self.num_inference_steps
- # creates integer timesteps by multiplying by ratio
- # casting to int to avoid issues when num_inference_step is power of 3

 # "leading" and "trailing" correspond to the annotations of Table 2 of https://arxiv.org/abs/2305.08891
- if self.config.timestep_scaling == "leading":
+ if self.config.timestep_spacing == "leading":
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
 timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
 timesteps += self.config.steps_offset
- elif self.config.timestep_scaling == "trailing":
+ elif self.config.timestep_spacing == "trailing":
+ step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
 timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64).copy()
 timesteps -= 1
 else:
 raise ValueError(
- f"{self.config.timestep_scaling} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
 )
 self.timesteps = torch.from_numpy(timesteps).to(device)

From a2fbd41e241e68866fdca62efff7dd7abbd53422 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Tue, 6 Jun 2023 23:16:14 +0000
Subject: [PATCH 09/14] Apply suggestions from Peter Lin

---
 .../alt_diffusion/pipeline_alt_diffusion.py | 6 +++---
 .../test_stable_diffusion.py | 19 +++++++++++++++++
 .../test_stable_diffusion_v_pred.py | 21 +++++++++++++++++++
 tests/schedulers/test_scheduler_ddim.py | 8 +++++++
 4 files changed, 51 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
index 89f2066400cb..fd274dc38f24 100644
--- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
+++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
@@ -51,8 +51,8 @@
 """
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_pred
-def rescale_noise_pred(noise_pred, noise_pred_text, guidance_rescale=0.0):
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=0.0):
 """
 Rescale `noise_pred` according to `guidance_rescale`.
Based on findings of [Common Diffusion Noise Schedules and
 Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
 """
@@ -730,7 +730,7 @@ def __call__(

 if do_classifier_free_guidance and guidance_rescale > 0.0:
 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
- noise_pred = rescale_noise_pred(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

 # compute the previous noisy sample x_t -> x_t-1
 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
index 3f9867783b33..bf3da56cd73f 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
@@ -207,6 +207,25 @@ def test_stable_diffusion_k_euler(self):

 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

+ def test_stable_diffusion_unflawed(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ components["scheduler"] = DDIMScheduler.from_config(components["scheduler"].config, timestep_spacing="trailing")
+ sd_pipe = StableDiffusionPipeline(**components)
+ sd_pipe = sd_pipe.to(device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ inputs["guidance_rescale"] = 0.7
+ inputs["num_inference_steps"] = 10
+ image = sd_pipe(**inputs).images
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 64, 64, 3)
+ expected_slice = np.array([0.4736, 0.5405, 0.4705, 0.4955, 0.5675, 0.4812, 0.5310, 0.4967, 0.5064])
+
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
 def test_stable_diffusion_long_prompt(self):
 components = self.get_dummy_components()
 components["scheduler"] = LMSDiscreteScheduler.from_config(components["scheduler"].config)

diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py
index d1a2c856659f..a6ccc24567b8 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py
@@ -384,6 +384,27 @@ def test_stable_diffusion_text2img_pipeline_v_pred_default(self):
 assert image.shape == (768, 768, 3)
 assert np.abs(expected_image - image).max() < 9e-1

+ def test_stable_diffusion_text2img_pipeline_unflawed(self):
+ expected_image = load_numpy(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/"
+ "sd2-text2img/lion_galaxy.npy"
+ )
+
+ pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", rescale_betas_zero_snr=True)
+ pipe.to(torch_device)
+ pipe.enable_attention_slicing()
+ pipe.set_progress_bar_config(disable=None)
+
+ prompt = "A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k"
+
+ generator = torch.manual_seed(0)
+ output = pipe(prompt=prompt, guidance_scale=7.5, guidance_rescale=0.7, generator=generator, output_type="np")
+ image = output.images[0]
+
+ assert image.shape == (768, 768, 3)
+ assert np.abs(expected_image - image).max() < 5e-1
+
 def test_stable_diffusion_text2img_pipeline_v_pred_fp16(self):
expected_image = load_numpy( "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/" diff --git a/tests/schedulers/test_scheduler_ddim.py b/tests/schedulers/test_scheduler_ddim.py index e9c85314d558..156b02b2208e 100644 --- a/tests/schedulers/test_scheduler_ddim.py +++ b/tests/schedulers/test_scheduler_ddim.py @@ -69,6 +69,14 @@ def test_clip_sample(self): for clip_sample in [True, False]: self.check_over_configs(clip_sample=clip_sample) + def test_timestep_spacing(self): + for timestep_spacing in ["trailing", "leading"]: + self.check_over_configs(timestep_spacing=timestep_spacing) + + def test_rescale_betas_zero_snr(self): + for rescale_betas_zero_snr in [True, False]: + self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr) + def test_thresholding(self): self.check_over_configs(thresholding=False) for threshold in [0.5, 1.0, 2.0]: From ca956582da69ea6c06a2d68083aa0574e5d653b8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 6 Jun 2023 23:20:06 +0000 Subject: [PATCH 10/14] make style --- .../alt_diffusion/pipeline_alt_diffusion.py | 12 ++++++------ .../stable_diffusion_2/test_stable_diffusion.py | 4 +++- .../test_stable_diffusion_v_pred.py | 8 +++++--- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 49dbaf3cc73f..bf8fc42eeaeb 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -52,20 +52,20 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg -def rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=0.0): +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): """ - Rescale `noise_pred` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
See Section 3.4
 """
 # std_text = torch.std(noise_pred_text)
 # std_pred = torch.std(noise_pred)
 std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
- std_pred = noise_pred.std(dim=list(range(1, noise_pred.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
 # rescale the results from guidance (fixes overexposure)
- noise_pred_rescaled = noise_pred * (std_text / std_pred)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
 # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
- noise_pred = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_pred
- return noise_pred
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg


 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
index 22bde6d6e311..33cc7f638ec2 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
@@ -211,7 +211,9 @@ def test_stable_diffusion_k_euler(self):
 def test_stable_diffusion_unflawed(self):
 device = "cpu" # ensure determinism for the device-dependent torch.Generator
 components = self.get_dummy_components()
- components["scheduler"] = DDIMScheduler.from_config(components["scheduler"].config, timestep_spacing="trailing")
+ components["scheduler"] = DDIMScheduler.from_config(
+ components["scheduler"].config, timestep_spacing="trailing"
+ )
 sd_pipe = StableDiffusionPipeline(**components)
 sd_pipe = sd_pipe.to(device)
 sd_pipe.set_progress_bar_config(disable=None)

diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py
index a6ccc24567b8..21862ba6a216 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py
@@ -386,12 +386,14 @@ def test_stable_diffusion_text2img_pipeline_v_pred_default(self):

 def test_stable_diffusion_text2img_pipeline_unflawed(self):
 expected_image = load_numpy(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/"
- "sd2-text2img/lion_galaxy.npy"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/"
+ "sd2-text2img/lion_galaxy.npy"
 )

 pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")
- pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", rescale_betas_zero_snr=True)
+ pipe.scheduler = DDIMScheduler.from_config(
+ pipe.scheduler.config, timestep_spacing="trailing", rescale_betas_zero_snr=True
+ )
 pipe.to(torch_device)
 pipe.enable_attention_slicing()
 pipe.set_progress_bar_config(disable=None)

From 147ca8e25d256aa65d30751eccf21b013b6035f8 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 7 Jun 2023 09:48:46 +0100
Subject: [PATCH 11/14] Apply suggestions from code review

---
 .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index b0dc9a3239ad..8368668ebea7 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -60,8 +60,6 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 """ - # std_text = torch.std(noise_pred_text) - # std_pred = torch.std(noise_pred) std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) # rescale the results from guidance (fixes overexposure) From 65532cdfb6ef54165c9c26a64601e57d0715747d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 7 Jun 2023 09:49:27 +0100 Subject: [PATCH 12/14] Apply suggestions from code review --- src/diffusers/schedulers/scheduling_ddim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 889c8dcc68c7..75c9d472cd86 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -312,7 +312,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic step_ratio = self.config.num_train_timesteps / self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64).copy() + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) timesteps -= 1 else: raise ValueError( From 03282f85013a668889fcc0aa1e59eaf12c6d4c91 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 7 Jun 2023 08:52:32 +0000 Subject: [PATCH 13/14] make style --- .../pipelines/alt_diffusion/pipeline_alt_diffusion.py | 2 -- src/diffusers/schedulers/scheduling_ddim.py | 3 ++- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index bf8fc42eeaeb..b79e4f72144b 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -57,8 +57,6 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
See Section 3.4 """ - # std_text = torch.std(noise_pred_text) - # std_pred = torch.std(noise_pred) std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) # rescale the results from guidance (fixes overexposure) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 75c9d472cd86..bab6f8acea03 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -164,7 +164,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): rescale_betas_zero_snr (`bool`, default `False`): whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). This can enable the model to generate very bright and dark samples instead of limiting it to samples with - medium brightness. + medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] From eb2e2585d8ede7299dc2924cc244af1f6ef282bd Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 7 Jun 2023 09:34:53 +0000 Subject: [PATCH 14/14] make style --- .../stable_diffusion/stable_diffusion_2.mdx | 58 +++++++++++++++++ docs/source/en/api/schedulers/ddim.mdx | 63 ++++++++++++++++++- examples/text_to_image/train_text_to_image.py | 10 +++ .../text_to_image/train_text_to_image_lora.py | 10 +++ 4 files changed, 140 insertions(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_2.mdx b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_2.mdx index e922072e4e31..7162626ebbde 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_2.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_2.mdx @@ -71,6 +71,64 @@ image = pipe(prompt, guidance_scale=9, num_inference_steps=25).images[0] image.save("astronaut.png") ``` +#### Experimental: "Common Diffusion Noise Schedules and Sample Steps are Flawed": + +The paper **[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891)** +claims that a mismatch between the training and inference settings leads to suboptimal inference generation results for Stable Diffusion. + +The abstract reads as follows: + +*We discover that common diffusion noise schedules do not enforce the last timestep to have zero signal-to-noise ratio (SNR), +and some implementations of diffusion samplers do not start from the last timestep. +Such designs are flawed and do not reflect the fact that the model is given pure Gaussian noise at inference, creating a discrepancy between training and inference. +We show that the flawed design causes real problems in existing implementations. +In Stable Diffusion, it severely limits the model to only generate images with medium brightness and +prevents it from generating very bright and dark samples. We propose a few simple fixes: +- (1) rescale the noise schedule to enforce zero terminal SNR; +- (2) train the model with v prediction; +- (3) change the sampler to always start from the last timestep; +- (4) rescale classifier-free guidance to prevent over-exposure. 
+These simple changes ensure the diffusion process is congruent between training and inference and
+allow the model to generate samples more faithful to the original data distribution.*
+
+You can apply all of these changes in `diffusers` when using [`DDIMScheduler`]:
+- (1) rescale the noise schedule to enforce zero terminal SNR;
+```py
+pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, rescale_betas_zero_snr=True)
+```
+- (2) train the model with v prediction;
+Continue fine-tuning a checkpoint with [`train_text_to_image.py`](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) or [`train_text_to_image_lora.py`](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py)
+and `--prediction_type="v_prediction"`.
+- (3) change the sampler to always start from the last timestep;
+```py
+pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+```
+- (4) rescale classifier-free guidance to prevent over-exposure.
+```py
+pipe(..., guidance_rescale=0.7)
+```
+
+An example is to use [this checkpoint](https://huggingface.co/ptx0/pseudo-journey-v2)
+which has been fine-tuned with `"v_prediction"`.
+
+The checkpoint can then be run in inference as follows:
+
+```py
+import torch
+
+from diffusers import DiffusionPipeline, DDIMScheduler
+
+pipe = DiffusionPipeline.from_pretrained("ptx0/pseudo-journey-v2", torch_dtype=torch.float16)
+pipe.scheduler = DDIMScheduler.from_config(
+ pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
+)
+pipe.to("cuda")
+
+prompt = "A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k"
+image = pipe(prompt, guidance_rescale=0.7).images[0]
+```
+
+## DDIMScheduler
+[[autodoc]] DDIMScheduler
+
 ### Image Inpainting

 - *Image Inpainting (512x512 resolution)*: [stabilityai/stable-diffusion-2-inpainting](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting) with [`StableDiffusionInpaintPipeline`]

diff --git a/docs/source/en/api/schedulers/ddim.mdx b/docs/source/en/api/schedulers/ddim.mdx
index 51b0cc3e9a09..0db5e4f4e2b5 100644
--- a/docs/source/en/api/schedulers/ddim.mdx
+++ b/docs/source/en/api/schedulers/ddim.mdx
@@ -18,10 +18,71 @@ specific language governing permissions and limitations under the License.

 The abstract of the paper is the following:

-Denoising diffusion probabilistic models (DDPMs) have achieved high quality image generation without adversarial training, yet they require simulating a Markov chain for many steps to produce a sample. To accelerate sampling, we present denoising diffusion implicit models (DDIMs), a more efficient class of iterative implicit probabilistic models with the same training procedure as DDPMs. In DDPMs, the generative process is defined as the reverse of a Markovian diffusion process. We construct a class of non-Markovian diffusion processes that lead to the same training objective, but whose reverse process can be much faster to sample from. We empirically demonstrate that DDIMs can produce high quality samples 10× to 50× faster in terms of wall-clock time compared to DDPMs, allow us to trade off computation for sample quality, and can perform semantically meaningful image interpolation directly in the latent space.
+*Denoising diffusion probabilistic models (DDPMs) have achieved high quality image generation without adversarial training,
+yet they require simulating a Markov chain for many steps to produce a sample.
+To accelerate sampling, we present denoising diffusion implicit models (DDIMs), a more efficient class of iterative implicit probabilistic models
+with the same training procedure as DDPMs. In DDPMs, the generative process is defined as the reverse of a Markovian diffusion process.
+We construct a class of non-Markovian diffusion processes that lead to the same training objective, but whose reverse process can be much faster to sample from.
+We empirically demonstrate that DDIMs can produce high quality samples 10× to 50× faster in terms of wall-clock time compared to DDPMs, allow us to trade off
+computation for sample quality, and can perform semantically meaningful image interpolation directly in the latent space.*

 The original codebase of this paper can be found here: [ermongroup/ddim](https://github.com/ermongroup/ddim).
 For questions, feel free to contact the author on [tsong.me](https://tsong.me/).

+### Experimental: "Common Diffusion Noise Schedules and Sample Steps are Flawed":
+
+The paper **[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891)**
+claims that a mismatch between the training and inference settings leads to suboptimal inference generation results for Stable Diffusion.
+
+The abstract reads as follows:
+
+*We discover that common diffusion noise schedules do not enforce the last timestep to have zero signal-to-noise ratio (SNR),
+and some implementations of diffusion samplers do not start from the last timestep.
+Such designs are flawed and do not reflect the fact that the model is given pure Gaussian noise at inference, creating a discrepancy between training and inference.
+We show that the flawed design causes real problems in existing implementations.
+In Stable Diffusion, it severely limits the model to only generate images with medium brightness and
+prevents it from generating very bright and dark samples. We propose a few simple fixes:
+- (1) rescale the noise schedule to enforce zero terminal SNR;
+- (2) train the model with v prediction;
+- (3) change the sampler to always start from the last timestep;
+- (4) rescale classifier-free guidance to prevent over-exposure.
+These simple changes ensure the diffusion process is congruent between training and inference and
+allow the model to generate samples more faithful to the original data distribution.*
+
+You can apply all of these changes in `diffusers` when using [`DDIMScheduler`]:
+- (1) rescale the noise schedule to enforce zero terminal SNR;
+```py
+pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, rescale_betas_zero_snr=True)
+```
+- (2) train the model with v prediction;
+Continue fine-tuning a checkpoint with [`train_text_to_image.py`](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) or [`train_text_to_image_lora.py`](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py)
+and `--prediction_type="v_prediction"`.
+- (3) change the sampler to always start from the last timestep;
+```py
+pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+```
+- (4) rescale classifier-free guidance to prevent over-exposure.
+```py
+pipe(..., guidance_rescale=0.7)
+```
+
+An example is to use [this checkpoint](https://huggingface.co/ptx0/pseudo-journey-v2)
+which has been fine-tuned with `"v_prediction"`.
+
+The checkpoint can then be run in inference as follows:
+
+```py
+import torch
+
+from diffusers import DiffusionPipeline, DDIMScheduler
+
+pipe = DiffusionPipeline.from_pretrained("ptx0/pseudo-journey-v2", torch_dtype=torch.float16)
+pipe.scheduler = DDIMScheduler.from_config(
+ pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
+)
+pipe.to("cuda")
+
+prompt = "A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k"
+image = pipe(prompt, guidance_rescale=0.7).images[0]
+```
+
 ## DDIMScheduler
 [[autodoc]] DDIMScheduler

diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index bbf7bf9b85bb..0965c77eea96 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -307,6 +307,12 @@ def parse_args():
 parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
 parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
 parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default=None,
+ help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None`, the default prediction type of the scheduler, `noise_scheduler.config.prediction_type`, is chosen.",
+ )
 parser.add_argument(
 "--hub_model_id",
 type=str,
@@ -848,6 +854,10 @@ def collate_fn(examples):
 encoder_hidden_states = text_encoder(batch["input_ids"])[0]

 # Get the target for loss depending on the prediction type
+ if args.prediction_type is not None:
+ # set prediction_type of scheduler if defined
+ noise_scheduler.register_to_config(prediction_type=args.prediction_type)
+
 if noise_scheduler.config.prediction_type == "epsilon":
 target = noise
 elif noise_scheduler.config.prediction_type == "v_prediction":
diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index 806637f04c53..30d527efd22d 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -272,6 +272,12 @@ def parse_args():
 parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
 parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
 parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default=None,
+ help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`.
If left to `None`, the default prediction type of the scheduler, `noise_scheduler.config.prediction_type`, is chosen.",
+ )
 parser.add_argument(
 "--hub_model_id",
 type=str,
@@ -749,6 +755,10 @@ def collate_fn(examples):
 encoder_hidden_states = text_encoder(batch["input_ids"])[0]

 # Get the target for loss depending on the prediction type
+ if args.prediction_type is not None:
+ # set prediction_type of scheduler if defined
+ noise_scheduler.register_to_config(prediction_type=args.prediction_type)
+
 if noise_scheduler.config.prediction_type == "epsilon":
 target = noise
 elif noise_scheduler.config.prediction_type == "v_prediction":
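For context on the new `--prediction_type` flag: when it is set to `"v_prediction"`, the regression target in the branch above becomes the velocity `v_t = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x_0` of the v-parameterization rather than the raw noise, which `diffusers` exposes as `noise_scheduler.get_velocity(latents, noise, timesteps)`. A minimal sketch of that computation, assuming `alphas_cumprod` is the scheduler's cumulative product of alphas (the tensor names here are illustrative, not taken from the training scripts):

```py
import torch


def velocity_target(alphas_cumprod, latents, noise, timesteps):
    # gather alpha_bar_t for each sample and broadcast it to the latent shape
    alpha_bar = alphas_cumprod[timesteps].to(latents.dtype)
    while alpha_bar.ndim < latents.ndim:
        alpha_bar = alpha_bar.unsqueeze(-1)
    # v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0
    return alpha_bar.sqrt() * noise - (1 - alpha_bar).sqrt() * latents
```

Together with `rescale_betas_zero_snr=True` and `timestep_spacing="trailing"` on the scheduler, this covers fixes (1), (2), and (3) proposed by the paper.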