Skip to content

Commit 074d281

Browse files
committed
tests and additional scheduler fixes
1 parent 953c9d1 commit 074d281

File tree

4 files changed

+36
-3
lines changed

4 files changed

+36
-3
lines changed

src/diffusers/schedulers/scheduling_deis_multistep.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def __init__(
171171
self.model_outputs = [None] * solver_order
172172
self.lower_order_nums = 0
173173

174+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_timesteps
174175
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
175176
"""
176177
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -181,14 +182,22 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
181182
device (`str` or `torch.device`, optional):
182183
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
183184
"""
184-
self.num_inference_steps = num_inference_steps
185185
timesteps = (
186186
np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
187187
.round()[::-1][:-1]
188188
.copy()
189189
.astype(np.int64)
190190
)
191+
192+
# when num_inference_steps == num_train_timesteps, we can end up with
193+
# duplicates in timesteps.
194+
_, unique_indices = np.unique(timesteps, return_index=True)
195+
timesteps = timesteps[np.sort(unique_indices)]
196+
191197
self.timesteps = torch.from_numpy(timesteps).to(device)
198+
199+
self.num_inference_steps = len(timesteps)
200+
192201
self.model_outputs = [
193202
None,
194203
] * self.config.solver_order

src/diffusers/schedulers/scheduling_unipc_multistep.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,21 +194,29 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
194194
device (`str` or `torch.device`, optional):
195195
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
196196
"""
197-
self.num_inference_steps = num_inference_steps
198197
timesteps = (
199198
np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
200199
.round()[::-1][:-1]
201200
.copy()
202201
.astype(np.int64)
203202
)
203+
204+
# when num_inference_steps == num_train_timesteps, we can end up with
205+
# duplicates in timesteps.
206+
_, unique_indices = np.unique(timesteps, return_index=True)
207+
timesteps = timesteps[np.sort(unique_indices)]
208+
204209
self.timesteps = torch.from_numpy(timesteps).to(device)
210+
211+
self.num_inference_steps = len(timesteps)
212+
205213
self.model_outputs = [
206214
None,
207215
] * self.config.solver_order
208216
self.lower_order_nums = 0
209217
self.last_sample = None
210218
if self.solver_p:
211-
self.solver_p.set_timesteps(num_inference_steps, device=device)
219+
self.solver_p.set_timesteps(self.num_inference_steps, device=device)
212220

213221
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
214222
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:

tests/schedulers/test_scheduler_dpm_multi.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,3 +243,11 @@ def test_fp16_support(self):
243243
sample = scheduler.step(residual, t, sample).prev_sample
244244

245245
assert sample.dtype == torch.float16
246+
247+
def test_unique_timesteps(self, **config):
248+
for scheduler_class in self.scheduler_classes:
249+
scheduler_config = self.get_scheduler_config(**config)
250+
scheduler = scheduler_class(**scheduler_config)
251+
252+
scheduler.set_timesteps(scheduler.config.num_train_timesteps)
253+
assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps

tests/schedulers/test_scheduler_unipc.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,3 +229,11 @@ def test_fp16_support(self):
229229
sample = scheduler.step(residual, t, sample).prev_sample
230230

231231
assert sample.dtype == torch.float16
232+
233+
def test_unique_timesteps(self, **config):
234+
for scheduler_class in self.scheduler_classes:
235+
scheduler_config = self.get_scheduler_config(**config)
236+
scheduler = scheduler_class(**scheduler_config)
237+
238+
scheduler.set_timesteps(scheduler.config.num_train_timesteps)
239+
assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps

0 commit comments

Comments (0)