Skip to content

Commit 904e977

Browse files
committed
fix some typos in dpm-solver pytorch
1 parent f0141e2 commit 904e977

File tree

1 file changed

+8
-10
lines changed

1 file changed

+8
-10
lines changed

src/diffusers/schedulers/scheduling_dpmsolver_discrete.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -104,13 +104,12 @@ def __init__(
104104
beta_end: float = 0.02,
105105
beta_schedule: str = "linear",
106106
trained_betas: Optional[np.ndarray] = None,
107-
steps_offset: int = 0,
108-
solver_order=3,
109-
predict_x0=True,
110-
thresholding=False,
111-
sample_max_value=1.0,
112-
solver_type="dpm_solver",
113-
denoise_final=True,
107+
solver_order: int = 3,
108+
predict_x0: bool = True,
109+
thresholding: bool = False,
110+
sample_max_value: float = 1.0,
111+
solver_type: str = "dpm_solver",
112+
denoise_final: bool = True,
114113
):
115114
if trained_betas is not None:
116115
self.betas = torch.from_numpy(trained_betas)
@@ -150,7 +149,7 @@ def __init__(
150149

151150
# setable values
152151
self.num_inference_steps = None
153-
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
152+
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
154153
self.timesteps = torch.from_numpy(timesteps)
155154
self.model_outputs = [None,] * self.solver_order
156155
self.lower_order_nums = 0
@@ -185,7 +184,7 @@ def convert_model_output(
185184
x0_pred = (sample - sigma_t * model_output) / alpha_t
186185
if self.thresholding:
187186
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
188-
p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
187+
p = 0.995 # A hyperparameter in the paper of "Imagen" (https://arxiv.org/abs/2205.11487).
189188
s = torch.quantile(torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), p, dim=1)
190189
s = torch.maximum(s, self.sample_max_value * torch.ones_like(s).to(s.device))[(...,) + (None,)*(x0_pred.ndim - 1)]
191190
x0_pred = torch.clamp(x0_pred, -s, s) / s
@@ -338,7 +337,6 @@ def step(
338337
denoise_second = (step_index == len(self.timesteps) - 2) and self.denoise_final
339338

340339
model_output = self.convert_model_output(model_output, timestep, sample)
341-
self.model_outputs.append(model_output)
342340
for i in range(self.solver_order - 1):
343341
self.model_outputs[i] = self.model_outputs[i + 1]
344342
self.model_outputs[-1] = model_output

0 commit comments

Comments
 (0)