1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ import math
1516from dataclasses import dataclass
1617from typing import Optional , Tuple , Union
1718
@@ -65,6 +66,10 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
6566 range is [0.2, 80.0].
6667 sigma_data (`float`, *optional*, defaults to 0.5):
6768 The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
69+ sigma_schedule (`str`, *optional*, defaults to `karras`):
70+ Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper
71+ (https://arxiv.org/abs/2206.00364). Other acceptable value is "exponential". The exponential schedule was
72+ incorporated in this model: https://huggingface.co/stabilityai/cosxl.
6873 num_train_timesteps (`int`, defaults to 1000):
6974 The number of diffusion steps to train the model.
7075 prediction_type (`str`, defaults to `epsilon`, *optional*):
@@ -84,15 +89,23 @@ def __init__(
8489 sigma_min : float = 0.002 ,
8590 sigma_max : float = 80.0 ,
8691 sigma_data : float = 0.5 ,
92+ sigma_schedule : str = "karras" ,
8793 num_train_timesteps : int = 1000 ,
8894 prediction_type : str = "epsilon" ,
8995 rho : float = 7.0 ,
9096 ):
97+ if sigma_schedule not in ["karras" , "exponential" ]:
98+ raise ValueError (f"Wrong value for provided for `{ sigma_schedule = } `.`" )
99+
91100 # setable values
92101 self .num_inference_steps = None
93102
94103 ramp = torch .linspace (0 , 1 , num_train_timesteps )
95- sigmas = self ._compute_sigmas (ramp )
104+ if sigma_schedule == "karras" :
105+ sigmas = self ._compute_karras_sigmas (ramp )
106+ elif sigma_schedule == "exponential" :
107+ sigmas = self ._compute_exponential_sigmas (ramp )
108+
96109 self .timesteps = self .precondition_noise (sigmas )
97110
98111 self .sigmas = torch .cat ([sigmas , torch .zeros (1 , device = sigmas .device )])
@@ -200,7 +213,10 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
200213 self .num_inference_steps = num_inference_steps
201214
202215 ramp = np .linspace (0 , 1 , self .num_inference_steps )
203- sigmas = self ._compute_sigmas (ramp )
216+ if self .config .sigma_schedule == "karras" :
217+ sigmas = self ._compute_karras_sigmas (ramp )
218+ elif self .config .sigma_schedule == "exponential" :
219+ sigmas = self ._compute_exponential_sigmas (ramp )
204220
205221 sigmas = torch .from_numpy (sigmas ).to (dtype = torch .float32 , device = device )
206222 self .timesteps = self .precondition_noise (sigmas )
@@ -211,16 +227,26 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
211227 self .sigmas = self .sigmas .to ("cpu" ) # to avoid too much CPU/GPU communication
212228
213229 # Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
214- def _compute_sigmas (self , ramp , sigma_min = None , sigma_max = None ) -> torch .FloatTensor :
230+ def _compute_karras_sigmas (self , ramp , sigma_min = None , sigma_max = None ) -> torch .FloatTensor :
215231 """Constructs the noise schedule of Karras et al. (2022)."""
216-
217232 sigma_min = sigma_min or self .config .sigma_min
218233 sigma_max = sigma_max or self .config .sigma_max
219234
220235 rho = self .config .rho
221236 min_inv_rho = sigma_min ** (1 / rho )
222237 max_inv_rho = sigma_max ** (1 / rho )
223238 sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho )) ** rho
239+
240+ return sigmas
241+
242+ def _compute_exponential_sigmas (self , ramp , sigma_min = None , sigma_max = None ) -> torch .FloatTensor :
243+ """Implementation closely follows k-diffusion.
244+
245+ https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
246+ """
247+ sigma_min = sigma_min or self .config .sigma_min
248+ sigma_max = sigma_max or self .config .sigma_max
249+ sigmas = torch .linspace (math .log (sigma_min ), math .log (sigma_max ), len (ramp )).exp ().flip (0 )
224250 return sigmas
225251
226252 # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
0 commit comments