Commit 870878d

Fix DDIM on Windows not using int64 for timesteps
1 parent: 323a9e1


src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 2 additions & 2 deletions
@@ -157,7 +157,7 @@ def __init__(
 
         # setable values
         self.num_inference_steps = None
-        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
+        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
 
     def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
         """
@@ -200,7 +200,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
         step_ratio = self.config.num_train_timesteps // self.num_inference_steps
         # creates integer timesteps by multiplying by ratio
         # casting to int to avoid issues when num_inference_step is power of 3
-        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
+        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
         self.timesteps = torch.from_numpy(timesteps).to(device)
         self.timesteps += offset
 
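Why the explicit cast is needed: with NumPy 1.x on Windows, np.arange without a dtype produces int32 (the platform's C long), and torch.from_numpy preserves that dtype, so the scheduler's timesteps tensor ends up as torch.int32 instead of the torch.int64 that timestep indexing expects. A minimal sketch of the difference (not part of the commit; num_train_timesteps = 1000 is just an illustrative value):

import numpy as np
import torch

num_train_timesteps = 1000  # illustrative value, not taken from the commit

# Platform-dependent: on Windows with NumPy 1.x, np.arange defaults to int32,
# and torch.from_numpy keeps whatever dtype the array has.
timesteps_default = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())

# Explicit cast: int64 on every platform, matching what timestep indexing expects.
timesteps_int64 = torch.from_numpy(
    np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)
)

print(timesteps_default.dtype)  # torch.int32 on Windows (NumPy 1.x), torch.int64 on Linux/macOS
print(timesteps_int64.dtype)    # torch.int64 everywhere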