diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py
index 3c0955ce4ea4..70e63a64c0a8 100644
--- a/src/diffusers/schedulers/scheduling_unipc_multistep.py
+++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py
@@ -127,6 +127,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
         steps_offset (`int`, defaults to 0):
             An offset added to the inference steps, as required by some model families.
+        final_sigmas_type (`str`, defaults to `"zero"`):
+            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
+            sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
     """
 
     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -153,6 +156,7 @@ def __init__(
         use_karras_sigmas: Optional[bool] = False,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
+        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -265,10 +269,25 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
             sigmas = np.flip(sigmas).copy()
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
-            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
+            if self.config.final_sigmas_type == "sigma_min":
+                sigma_last = sigmas[-1]
+            elif self.config.final_sigmas_type == "zero":
+                sigma_last = 0
+            else:
+                raise ValueError(
+                    f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+                )
+            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
-            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
+            if self.config.final_sigmas_type == "sigma_min":
+                sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
+            elif self.config.final_sigmas_type == "zero":
+                sigma_last = 0
+            else:
+                raise ValueError(
+                    f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+                )
             sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
 
         self.sigmas = torch.from_numpy(sigmas)
diff --git a/tests/schedulers/test_scheduler_unipc.py b/tests/schedulers/test_scheduler_unipc.py
index be41cea95b67..fa34ef75b52c 100644
--- a/tests/schedulers/test_scheduler_unipc.py
+++ b/tests/schedulers/test_scheduler_unipc.py
@@ -24,6 +24,7 @@ def get_scheduler_config(self, **kwargs):
             "beta_schedule": "linear",
             "solver_order": 2,
             "solver_type": "bh2",
+            "final_sigmas_type": "sigma_min",
         }
 
         config.update(**kwargs)
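For reviewers who want to exercise the change locally, here is a minimal sketch, assuming the patch above is applied; the 20-step count is arbitrary and nothing here beyond `final_sigmas_type` comes from the diff:

```python
from diffusers import UniPCMultistepScheduler

# New default: the sampling schedule ends exactly at sigma = 0.
scheduler = UniPCMultistepScheduler(final_sigmas_type="zero")
scheduler.set_timesteps(num_inference_steps=20)
print(scheduler.sigmas[-1])  # tensor(0.)

# Previous behavior, as pinned in the updated test config above:
# the trailing sigma is the smallest sigma of the training schedule.
scheduler = UniPCMultistepScheduler(final_sigmas_type="sigma_min")
scheduler.set_timesteps(num_inference_steps=20)
print(scheduler.sigmas[-1])  # > 0

# Any other value raises the ValueError added in set_timesteps.
```

Note that the tests pin `final_sigmas_type="sigma_min"` so the existing expected values keep passing while `"zero"` becomes the default for new configs.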