Skip to content

Commit 729273d

Browse files
committed
Fix the LMS pytorch regression
1 parent 235770d commit 729273d

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

src/diffusers/schedulers/scheduling_lms_discrete.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,25 +85,26 @@ def __init__(
8585
)
8686

8787
if trained_betas is not None:
88-
self.betas = torch.from_numpy(trained_betas)
88+
self.betas = np.asarray(trained_betas)
8989
if beta_schedule == "linear":
90-
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
90+
self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
9191
elif beta_schedule == "scaled_linear":
9292
# this schedule is very specific to the latent diffusion model.
9393
self.betas = (
94-
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
94+
np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
9595
)
9696
else:
9797
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
9898

99-
self.alphas = 1.0 - self.betas
100-
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
99+
self.alphas = np.array(1.0 - self.betas, dtype=np.float32)
100+
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
101101

102-
self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
102+
sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
103+
self.sigmas = torch.from_numpy(sigmas)
103104

104105
# setable values
105106
self.num_inference_steps = None
106-
self.timesteps = np.arange(0, num_train_timesteps)[::-1] # to be consistent has to be smaller than sigmas by 1
107+
self.timesteps = np.arange(0, num_train_timesteps)[::-1]
107108
self.derivatives = []
108109

109110
def get_lms_coefficient(self, order, t, current_order):
@@ -146,8 +147,8 @@ def set_timesteps(self, num_inference_steps: int):
146147
sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
147148
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
148149
self.sigmas = torch.from_numpy(sigmas)
150+
self.timesteps = timesteps
149151

150-
self.timesteps = timesteps.astype(int)
151152
self.derivatives = []
152153

153154
def step(

tests/test_scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -876,5 +876,5 @@ def test_full_loop_no_noise(self):
876876
result_sum = torch.sum(torch.abs(sample))
877877
result_mean = torch.mean(torch.abs(sample))
878878

879-
assert abs(result_sum.item() - 1006.370) < 1e-2
879+
assert abs(result_sum.item() - 1006.388) < 1e-2
880880
assert abs(result_mean.item() - 1.31) < 1e-3

0 commit comments

Comments
 (0)