Skip to content

Commit a8a7e57

Browse files
timestep scheduling with np.linspace (#8623)
Fixes #8600 ### Description The `np.linspace` approach generates a descending array that starts exactly at 999 and ends exactly at 0 (after rounding), ensuring the scheduler samples the entire intended trajectory. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: ytl0623 <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f315bcb commit a8a7e57

File tree

2 files changed

+11
-18
lines changed

2 files changed

+11
-18
lines changed

monai/networks/schedulers/ddim.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -117,18 +117,14 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N
117117
)
118118

119119
self.num_inference_steps = num_inference_steps
120-
step_ratio = self.num_train_timesteps // self.num_inference_steps
121-
if self.steps_offset >= step_ratio:
122-
raise ValueError(
123-
f"`steps_offset`: {self.steps_offset} cannot be greater than or equal to "
124-
f"`num_train_timesteps // num_inference_steps : {step_ratio}` as this will cause timesteps to exceed"
125-
f" the max train timestep."
126-
)
127-
128-
# creates integer timesteps by multiplying by ratio
129-
# casting to int to avoid issues when num_inference_step is power of 3
130-
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
131-
self.timesteps = torch.from_numpy(timesteps).to(device)
120+
if self.steps_offset < 0 or self.steps_offset >= self.num_train_timesteps:
121+
raise ValueError(f"`steps_offset`: {self.steps_offset} must be in range [0, {self.num_train_timesteps}).")
122+
123+
self.timesteps = (
124+
torch.linspace((self.num_train_timesteps - 1) - self.steps_offset, 0, num_inference_steps, device=device)
125+
.round()
126+
.long()
127+
)
132128
self.timesteps += self.steps_offset
133129

134130
def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor:

monai/networks/schedulers/ddpm.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131

3232
from __future__ import annotations
3333

34-
import numpy as np
3534
import torch
3635

3736
from monai.utils import StrEnum
@@ -122,11 +121,9 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N
122121
)
123122

124123
self.num_inference_steps = num_inference_steps
125-
step_ratio = self.num_train_timesteps // self.num_inference_steps
126-
# creates integer timesteps by multiplying by ratio
127-
# casting to int to avoid issues when num_inference_step is power of 3
128-
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64)
129-
self.timesteps = torch.from_numpy(timesteps).to(device)
124+
self.timesteps = (
125+
torch.linspace(self.num_train_timesteps - 1, 0, self.num_inference_steps, device=device).round().long()
126+
)
130127

131128
def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor:
132129
"""

0 commit comments

Comments
 (0)