diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index c8b1f2c3bedf..b8205455d6d9 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -201,7 +201,7 @@ def set_timesteps( else: timesteps = torch.from_numpy(timesteps).to(device) - timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device) + timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype) interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten() self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps]) diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py index 809da798f889..b49cc2e54412 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -190,7 +190,7 @@ def set_timesteps( timesteps = torch.from_numpy(timesteps).to(device) # interpolate timesteps - timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device) + timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype) interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten() self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])