@@ -108,7 +108,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
108108 The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
109109 `algorithm_type="dpmsolver++"`.
110110 algorithm_type (`str`, defaults to `dpmsolver++`):
111- Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde- dpmsolver++`. The
111+ Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++`. The
112112 `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
113113 paper, and the `dpmsolver++` type implements the algorithms in the
114114 [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
@@ -122,6 +122,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
122122 use_karras_sigmas (`bool`, *optional*, defaults to `False`):
123123 Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
124124 the sigmas are determined according to a sequence of noise levels {σi}.
125+ final_sigmas_type (`str`, *optional*, defaults to `"zero"`):
126+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma
127+ is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
125128 lambda_min_clipped (`float`, defaults to `-inf`):
126129 Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
127130 cosine (`squaredcos_cap_v2`) noise schedule.
@@ -150,9 +153,14 @@ def __init__(
150153 solver_type : str = "midpoint" ,
151154 lower_order_final : bool = True ,
152155 use_karras_sigmas : Optional [bool ] = False ,
156+ final_sigmas_type : Optional [str ] = "zero" , # "zero", "sigma_min"
153157 lambda_min_clipped : float = - float ("inf" ),
154158 variance_type : Optional [str ] = None ,
155159 ):
160+ if algorithm_type == "dpmsolver" :
161+ deprecation_message = "algorithm_type `dpmsolver` is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
162+ deprecate ("algorithm_types=dpmsolver" , "1.0.0" , deprecation_message )
163+
156164 if trained_betas is not None :
157165 self .betas = torch .tensor (trained_betas , dtype = torch .float32 )
158166 elif beta_schedule == "linear" :
@@ -189,6 +197,11 @@ def __init__(
189197 else :
190198 raise NotImplementedError (f"{ solver_type } does is not implemented for { self .__class__ } " )
191199
200+ if algorithm_type != "dpmsolver++" and final_sigmas_type == "zero" :
201+ raise ValueError (
202+ f"`final_sigmas_type` { final_sigmas_type } is not supported for `algorithm_type` { algorithm_type } . Please chooose `sigma_min` instead."
203+ )
204+
192205 # setable values
193206 self .num_inference_steps = None
194207 timesteps = np .linspace (0 , num_train_timesteps - 1 , num_train_timesteps , dtype = np .float32 )[::- 1 ].copy ()
@@ -267,11 +280,18 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
267280 sigmas = np .flip (sigmas ).copy ()
268281 sigmas = self ._convert_to_karras (in_sigmas = sigmas , num_inference_steps = num_inference_steps )
269282 timesteps = np .array ([self ._sigma_to_t (sigma , log_sigmas ) for sigma in sigmas ]).round ()
270- sigmas = np .concatenate ([sigmas , sigmas [- 1 :]]).astype (np .float32 )
271283 else :
272284 sigmas = np .interp (timesteps , np .arange (0 , len (sigmas )), sigmas )
285+
286+ if self .config .final_sigmas_type == "sigma_min" :
273287 sigma_last = ((1 - self .alphas_cumprod [0 ]) / self .alphas_cumprod [0 ]) ** 0.5
274- sigmas = np .concatenate ([sigmas , [sigma_last ]]).astype (np .float32 )
288+ elif self .config .final_sigmas_type == "zero" :
289+ sigma_last = 0
290+ else :
291+ raise ValueError (
292+ f" `final_sigmas_type` must be one of `sigma_min` or `zero`, but got { self .config .final_sigmas_type } "
293+ )
294+ sigmas = np .concatenate ([sigmas , [sigma_last ]]).astype (np .float32 )
275295
276296 self .sigmas = torch .from_numpy (sigmas ).to (device = device )
277297
@@ -285,6 +305,12 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
285305 )
286306 self .register_to_config (lower_order_final = True )
287307
308+ if not self .config .lower_order_final and self .config .final_sigmas_type == "zero" :
309+ logger .warn (
310+ " `last_sigmas_type='zero'` is not supported for `lower_order_final=False`. Changing scheduler {self.config} to have `lower_order_final` set to True."
311+ )
312+ self .register_to_config (lower_order_final = True )
313+
288314 self .order_list = self .get_order_list (num_inference_steps )
289315
290316 # add an index counter for schedulers that allow duplicated timesteps
0 commit comments