Skip to content

Commit 84d7fae

Browse files
authored
Fix support for MPS in KDPM2AncestralDiscreteScheduler (#6365)
Fix support for MPS MPS doesn't support float64
1 parent 4c483de commit 84d7fae

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,11 @@ def set_timesteps(
277277
self.sigmas_up = torch.cat([sigmas_up[:1], sigmas_up[1:].repeat_interleave(2), sigmas_up[-1:]])
278278
self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]])
279279

280-
timesteps = torch.from_numpy(timesteps).to(device)
280+
if str(device).startswith("mps"):
281+
timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
282+
else:
283+
timesteps = torch.from_numpy(timesteps).to(device)
284+
281285
sigmas_interpol = sigmas_interpol.cpu()
282286
log_sigmas = self.log_sigmas.cpu()
283287
timesteps_interpol = np.array(

0 commit comments

Comments
 (0)