Skip to content

Commit 91efbc8

Browse files
committed
Update monai/networks/schedulers/ddim.py
Signed-off-by: ytl0623 <[email protected]>
1 parent 7b71a61 commit 91efbc8

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

monai/networks/schedulers/ddim.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,11 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N
120120
if self.steps_offset < 0 or self.steps_offset >= self.num_train_timesteps:
121121
raise ValueError(f"`steps_offset`: {self.steps_offset} must be in range [0, {self.num_train_timesteps}).")
122122

123-
self.timesteps = torch.linspace((self.num_train_timesteps - 1) - self.steps_offset, 0, num_inference_steps, device=device).round().long()
123+
self.timesteps = (
124+
torch.linspace((self.num_train_timesteps - 1) - self.steps_offset, 0, num_inference_steps, device=device)
125+
.round()
126+
.long()
127+
)
124128
self.timesteps += self.steps_offset
125129

126130
def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor:

monai/networks/schedulers/ddpm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,9 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N
122122
)
123123

124124
self.num_inference_steps = num_inference_steps
125-
self.timesteps = torch.linspace(self.num_train_timesteps - 1, 0, self.num_inference_steps, device=device).round().long()
126-
125+
self.timesteps = (
126+
torch.linspace(self.num_train_timesteps - 1, 0, self.num_inference_steps, device=device).round().long()
127+
)
127128

128129
def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor:
129130
"""

0 commit comments

Comments
 (0)