From b96ec40ae12b403432a9d0855eb799df6dde7a10 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Wed, 5 Nov 2025 16:50:21 +0800 Subject: [PATCH 1/9] timestep scheduling with np.linspace Signed-off-by: ytl0623 --- monai/networks/schedulers/ddim.py | 2 +- monai/networks/schedulers/ddpm.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/networks/schedulers/ddim.py b/monai/networks/schedulers/ddim.py index 50a680336d..993b826727 100644 --- a/monai/networks/schedulers/ddim.py +++ b/monai/networks/schedulers/ddim.py @@ -127,7 +127,7 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps).round().astype(np.int64) self.timesteps = torch.from_numpy(timesteps).to(device) self.timesteps += self.steps_offset diff --git a/monai/networks/schedulers/ddpm.py b/monai/networks/schedulers/ddpm.py index e2b7ab55f5..7ef27108bc 100644 --- a/monai/networks/schedulers/ddpm.py +++ b/monai/networks/schedulers/ddpm.py @@ -125,7 +125,7 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N step_ratio = self.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64) + timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps).round().astype(np.int64) self.timesteps = torch.from_numpy(timesteps).to(device) def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor: From 3a2ffc51a261cde49bee3ce6c11adb6d5d96ed0a Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Wed, 5 Nov 2025 16:58:07 +0800 Subject: [PATCH 2/9] remove step_ratio variable Signed-off-by: ytl0623 --- monai/networks/schedulers/ddpm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/networks/schedulers/ddpm.py b/monai/networks/schedulers/ddpm.py index 7ef27108bc..76dea2bb83 100644 --- a/monai/networks/schedulers/ddpm.py +++ b/monai/networks/schedulers/ddpm.py @@ -122,7 +122,6 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N ) self.num_inference_steps = num_inference_steps - step_ratio = self.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps).round().astype(np.int64) From aa3cce606282120d3512a5c12f59f8def3ea7ad3 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Wed, 5 Nov 2025 17:01:50 +0800 Subject: [PATCH 3/9] steps_offset causes out-of-bounds timestep indices Signed-off-by: ytl0623 --- monai/networks/schedulers/ddim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/schedulers/ddim.py b/monai/networks/schedulers/ddim.py index 993b826727..cacc663c33 100644 --- a/monai/networks/schedulers/ddim.py +++ b/monai/networks/schedulers/ddim.py @@ -127,7 +127,7 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps).round().astype(np.int64) + timesteps = np.linspace(num_train_timesteps - 1 - steps_offset, 0, num_inference_steps).round().astype(np.int64) self.timesteps = torch.from_numpy(timesteps).to(device) self.timesteps += self.steps_offset From 616f5f2349c86030e4607b8d2e77f02cec9a3f0f Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Wed, 5 Nov 2025 18:20:08 +0800 Subject: [PATCH 4/9] Refactor timesteps calculation in ddim.py Signed-off-by: ytl0623 --- monai/networks/schedulers/ddim.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/monai/networks/schedulers/ddim.py b/monai/networks/schedulers/ddim.py index cacc663c33..692ef90872 100644 --- a/monai/networks/schedulers/ddim.py +++ b/monai/networks/schedulers/ddim.py @@ -125,9 +125,7 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N f" the max train timestep." ) - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = np.linspace(num_train_timesteps - 1 - steps_offset, 0, num_inference_steps).round().astype(np.int64) + timesteps = np.linspace((self.num_train_timesteps - 1) - self.steps_offset, 0, num_inference_steps).round().astype(np.int64) self.timesteps = torch.from_numpy(timesteps).to(device) self.timesteps += self.steps_offset From f7ca165138d006d3a13fcfe3c1a9d379738cb54d Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Wed, 5 Nov 2025 18:20:56 +0800 Subject: [PATCH 5/9] Fix timesteps calculation in ddpm.py Signed-off-by: ytl0623 --- monai/networks/schedulers/ddpm.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/monai/networks/schedulers/ddpm.py b/monai/networks/schedulers/ddpm.py index 76dea2bb83..1b91398d4e 100644 --- a/monai/networks/schedulers/ddpm.py +++ b/monai/networks/schedulers/ddpm.py @@ -122,9 +122,7 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N ) self.num_inference_steps = num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps).round().astype(np.int64) + timesteps = np.linspace(self.num_train_timesteps - 1, 0, self.num_inference_steps).round().astype(np.int64) self.timesteps = torch.from_numpy(timesteps).to(device) def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor: From ef7e83cbf5c74e7cd8841164b59ebb16b57b6752 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Thu, 6 Nov 2025 08:52:24 +0800 Subject: [PATCH 6/9] With the linspace approach, max_timestep = (num_train_timesteps - 1 - steps_offset) + steps_offset = num_train_timesteps - 1 regardless of the relationship between steps_offset and step_ratio. The actual constraint is 0 <= steps_offset < num_train_timesteps. Signed-off-by: ytl0623 --- monai/networks/schedulers/ddim.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/monai/networks/schedulers/ddim.py b/monai/networks/schedulers/ddim.py index 692ef90872..9e2c5c39e3 100644 --- a/monai/networks/schedulers/ddim.py +++ b/monai/networks/schedulers/ddim.py @@ -117,15 +117,14 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N ) self.num_inference_steps = num_inference_steps - step_ratio = self.num_train_timesteps // self.num_inference_steps - if self.steps_offset >= step_ratio: - raise ValueError( - f"`steps_offset`: {self.steps_offset} cannot be greater than or equal to " - f"`num_train_timesteps // num_inference_steps : {step_ratio}` as this will cause timesteps to exceed" - f" the max train timestep." - ) - - timesteps = np.linspace((self.num_train_timesteps - 1) - self.steps_offset, 0, num_inference_steps).round().astype(np.int64) + if self.steps_offset < 0 or self.steps_offset >= self.num_train_timesteps: + raise ValueError(f"`steps_offset`: {self.steps_offset} must be in range [0, {self.num_train_timesteps}).") + + timesteps = ( + np.linspace((self.num_train_timesteps - 1) - self.steps_offset, 0, num_inference_steps) + .round() + .astype(np.int64) + ) self.timesteps = torch.from_numpy(timesteps).to(device) self.timesteps += self.steps_offset From 7b71a61efc8e7815a059aaf5544eaf824f0b9ba3 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Tue, 11 Nov 2025 09:01:58 +0800 Subject: [PATCH 7/9] Update monai/networks/schedulers/ddim.py Signed-off-by: ytl0623 --- monai/networks/schedulers/ddim.py | 7 +------ monai/networks/schedulers/ddpm.py | 4 ++-- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/monai/networks/schedulers/ddim.py b/monai/networks/schedulers/ddim.py index 9e2c5c39e3..acff336f61 100644 --- a/monai/networks/schedulers/ddim.py +++ b/monai/networks/schedulers/ddim.py @@ -120,12 +120,7 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N if self.steps_offset < 0 or self.steps_offset >= self.num_train_timesteps: raise ValueError(f"`steps_offset`: {self.steps_offset} must be in range [0, {self.num_train_timesteps}).") - timesteps = ( - np.linspace((self.num_train_timesteps - 1) - self.steps_offset, 0, num_inference_steps) - .round() - .astype(np.int64) - ) - self.timesteps = torch.from_numpy(timesteps).to(device) + self.timesteps = torch.linspace((self.num_train_timesteps - 1) - self.steps_offset, 0, num_inference_steps, device=device).round().long() self.timesteps += self.steps_offset def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor: diff --git a/monai/networks/schedulers/ddpm.py b/monai/networks/schedulers/ddpm.py index 1b91398d4e..8934a230c2 100644 --- a/monai/networks/schedulers/ddpm.py +++ b/monai/networks/schedulers/ddpm.py @@ -122,8 +122,8 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N ) self.num_inference_steps = num_inference_steps - timesteps = np.linspace(self.num_train_timesteps - 1, 0, self.num_inference_steps).round().astype(np.int64) - self.timesteps = torch.from_numpy(timesteps).to(device) + self.timesteps = torch.linspace(self.num_train_timesteps - 1, 0, self.num_inference_steps, device=device).round().long() + def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor: """ From 84e917f8099e0f1e0efdca45bce14687daf25d5a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 11 Nov 2025 01:02:52 +0000 Subject: [PATCH 8/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/schedulers/ddpm.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/networks/schedulers/ddpm.py b/monai/networks/schedulers/ddpm.py index 8934a230c2..0507d284e2 100644 --- a/monai/networks/schedulers/ddpm.py +++ b/monai/networks/schedulers/ddpm.py @@ -31,7 +31,6 @@ from __future__ import annotations -import numpy as np import torch from monai.utils import StrEnum @@ -123,7 +122,7 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N self.num_inference_steps = num_inference_steps self.timesteps = torch.linspace(self.num_train_timesteps - 1, 0, self.num_inference_steps, device=device).round().long() - + def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor: """ From 91efbc8877c10d1959cf43a94f5e175788d42b4d Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Tue, 11 Nov 2025 09:04:02 +0800 Subject: [PATCH 9/9] Update monai/networks/schedulers/ddim.py Signed-off-by: ytl0623 --- monai/networks/schedulers/ddim.py | 6 +++++- monai/networks/schedulers/ddpm.py | 5 +++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/monai/networks/schedulers/ddim.py b/monai/networks/schedulers/ddim.py index acff336f61..9d843c6898 100644 --- a/monai/networks/schedulers/ddim.py +++ b/monai/networks/schedulers/ddim.py @@ -120,7 +120,11 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N if self.steps_offset < 0 or self.steps_offset >= self.num_train_timesteps: raise ValueError(f"`steps_offset`: {self.steps_offset} must be in range [0, {self.num_train_timesteps}).") - self.timesteps = torch.linspace((self.num_train_timesteps - 1) - self.steps_offset, 0, num_inference_steps, device=device).round().long() + self.timesteps = ( + torch.linspace((self.num_train_timesteps - 1) - self.steps_offset, 0, num_inference_steps, device=device) + .round() + .long() + ) self.timesteps += self.steps_offset def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor: diff --git a/monai/networks/schedulers/ddpm.py b/monai/networks/schedulers/ddpm.py index 8934a230c2..61c037bb36 100644 --- a/monai/networks/schedulers/ddpm.py +++ b/monai/networks/schedulers/ddpm.py @@ -122,8 +122,9 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N ) self.num_inference_steps = num_inference_steps - self.timesteps = torch.linspace(self.num_train_timesteps - 1, 0, self.num_inference_steps, device=device).round().long() - + self.timesteps = ( + torch.linspace(self.num_train_timesteps - 1, 0, self.num_inference_steps, device=device).round().long() + ) def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor: """