1616
1717import math
1818from typing import Optional , Tuple , Union
19+ import warnings
1920
2021import numpy as np
2122import torch
@@ -73,10 +74,17 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
7374 the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
7475 `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
7576 trained_betas (`np.ndarray`, optional): TODO
76- tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays
7777 skip_prk_steps (`bool`):
7878 allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
7979 before plms steps; defaults to `False`.
80+ set_alpha_to_one (`bool`, default `True`):
81+ each diffusion step uses the value of alphas product at that step and at the previous one.
82+ For the final step there is no previous alpha. When this option is `True` the previous alpha
83+ product is fixed to `1`, otherwise it uses the value of alpha at step 0.
84+ steps_offset (`int`, default `0`):
85+ an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`,
86+ to make the last step use step 0 for the previous alpha product.
87+ tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays
8088
8189 """
8290
@@ -88,8 +96,10 @@ def __init__(
8896 beta_end : float = 0.02 ,
8997 beta_schedule : str = "linear" ,
9098 trained_betas : Optional [np .ndarray ] = None ,
91- tensor_format : str = "pt" ,
9299 skip_prk_steps : bool = False ,
100+ set_alpha_to_one : bool = True ,
101+ steps_offset : int = 0 ,
102+ tensor_format : str = "pt" ,
93103 ):
94104 if trained_betas is not None :
95105 self .betas = np .asarray (trained_betas )
@@ -107,6 +117,8 @@ def __init__(
107117 self .alphas = 1.0 - self .betas
108118 self .alphas_cumprod = np .cumprod (self .alphas , axis = 0 )
109119
120+ self .final_alpha_cumprod = np .array (1.0 ) if set_alpha_to_one else self .alphas_cumprod [0 ]
121+
110122 self .one = np .array (1.0 )
111123
112124 # For now we only support F-PNDM, i.e. the runge-kutta method
@@ -123,29 +135,37 @@ def __init__(
123135 # setable values
124136 self .num_inference_steps = None
125137 self ._timesteps = np .arange (0 , num_train_timesteps )[::- 1 ].copy ()
126- self ._offset = 0
127138 self .prk_timesteps = None
128139 self .plms_timesteps = None
129140 self .timesteps = None
130141
131142 self .tensor_format = tensor_format
132143 self .set_format (tensor_format = tensor_format )
133144
134- def set_timesteps (self , num_inference_steps : int , offset : int = 0 ) -> torch .FloatTensor :
145+ def set_timesteps (self , num_inference_steps : int , ** kwargs ) -> torch .FloatTensor :
135146 """
136147 Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
137148
138149 Args:
139150 num_inference_steps (`int`):
140151 the number of diffusion steps used when generating samples with a pre-trained model.
141- offset (`int`): TODO
142152 """
153+
154+ offset = self .config .steps_offset
155+
156+ if "offset" in kwargs :
157+ warnings .warn (
158+ "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
159+ " Please pass `steps_offset` to `__init__` instead."
160+ )
161+
162+ offset = kwargs ["offset" ]
163+
143164 self .num_inference_steps = num_inference_steps
144165 self ._timesteps = list (
145166 range (0 , self .config .num_train_timesteps , self .config .num_train_timesteps // num_inference_steps )
146167 )
147- self ._offset = offset
148- self ._timesteps = np .array ([t + self ._offset for t in self ._timesteps ])
168+ self ._timesteps = np .array (self ._timesteps ) + offset
149169
150170 if self .config .skip_prk_steps :
151171 # for some models like stable diffusion the prk steps can/should be skipped to
@@ -322,7 +342,7 @@ def step_plms(
322342
323343 return SchedulerOutput (prev_sample = prev_sample )
324344
325- def _get_prev_sample (self , sample , timestep , timestep_prev , model_output ):
345+ def _get_prev_sample (self , sample , timestep , prev_timestep , model_output ):
326346 # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
327347 # this function computes x_(t−δ) using the formula of (9)
328348 # Note that x_t needs to be added to both sides of the equation
@@ -335,8 +355,8 @@ def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
335355 # sample -> x_t
336356 # model_output -> e_θ(x_t, t)
337357 # prev_sample -> x_(t−δ)
338- alpha_prod_t = self .alphas_cumprod [timestep + 1 - self . _offset ]
339- alpha_prod_t_prev = self .alphas_cumprod [timestep_prev + 1 - self ._offset ]
358+ alpha_prod_t = self .alphas_cumprod [timestep ]
359+ alpha_prod_t_prev = self .alphas_cumprod [prev_timestep ] if prev_timestep >= 0 else self .final_alpha_cumprod
340360 beta_prod_t = 1 - alpha_prod_t
341361 beta_prod_t_prev = 1 - alpha_prod_t_prev
342362
0 commit comments