@@ -104,13 +104,12 @@ def __init__(
104104 beta_end : float = 0.02 ,
105105 beta_schedule : str = "linear" ,
106106 trained_betas : Optional [np .ndarray ] = None ,
107- steps_offset : int = 0 ,
108- solver_order = 3 ,
109- predict_x0 = True ,
110- thresholding = False ,
111- sample_max_value = 1.0 ,
112- solver_type = "dpm_solver" ,
113- denoise_final = True ,
107+ solver_order : int = 3 ,
108+ predict_x0 : bool = True ,
109+ thresholding : bool = False ,
110+ sample_max_value : float = 1.0 ,
111+ solver_type : str = "dpm_solver" ,
112+ denoise_final : bool = True ,
114113 ):
115114 if trained_betas is not None :
116115 self .betas = torch .from_numpy (trained_betas )
@@ -150,7 +149,7 @@ def __init__(
150149
151150 # setable values
152151 self .num_inference_steps = None
153- timesteps = np .linspace (0 , num_train_timesteps - 1 , num_train_timesteps , dtype = float )[::- 1 ].copy ()
152+ timesteps = np .linspace (0 , num_train_timesteps - 1 , num_train_timesteps , dtype = np . float32 )[::- 1 ].copy ()
154153 self .timesteps = torch .from_numpy (timesteps )
155154 self .model_outputs = [None ,] * self .solver_order
156155 self .lower_order_nums = 0
@@ -185,7 +184,7 @@ def convert_model_output(
185184 x0_pred = (sample - sigma_t * model_output ) / alpha_t
186185 if self .thresholding :
187186 # Dynamic thresholding in https://arxiv.org/abs/2205.11487
188- p = 0.995 # A hyperparameter in the paper of "Imagen" [1] .
187+ p = 0.995 # A hyperparameter in the paper of "Imagen" (https://arxiv.org/abs/2205.11487) .
189188 s = torch .quantile (torch .abs (x0_pred ).reshape ((x0_pred .shape [0 ], - 1 )), p , dim = 1 )
190189 s = torch .maximum (s , self .sample_max_value * torch .ones_like (s ).to (s .device ))[(...,) + (None ,)* (x0_pred .ndim - 1 )]
191190 x0_pred = torch .clamp (x0_pred , - s , s ) / s
@@ -338,7 +337,6 @@ def step(
338337 denoise_second = (step_index == len (self .timesteps ) - 2 ) and self .denoise_final
339338
340339 model_output = self .convert_model_output (model_output , timestep , sample )
341- self .model_outputs .append (model_output )
342340 for i in range (self .solver_order - 1 ):
343341 self .model_outputs [i ] = self .model_outputs [i + 1 ]
344342 self .model_outputs [- 1 ] = model_output
0 commit comments