Skip to content

Commit f3c48d9

Browse files
committed
ddpm custom timesteps
add custom timesteps test add custom timesteps descending order check docs
1 parent d0f2582 commit f3c48d9

File tree

2 files changed

+80
-7
lines changed

2 files changed

+80
-7
lines changed

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ def __init__(
162162
self.init_noise_sigma = 1.0
163163

164164
# setable values
165+
self.custom_timesteps = False
165166
self.num_inference_steps = None
166167
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
167168

@@ -191,14 +192,31 @@ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] =
191192
"""
192193
return sample
193194

194-
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
195+
def set_timesteps(
196+
self,
197+
num_inference_steps: int,
198+
device: Union[str, torch.device] = None,
199+
custom_timesteps: Optional[List[int]] = None,
200+
):
195201
"""
196202
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
197203
198204
Args:
199205
num_inference_steps (`int`):
200206
the number of diffusion steps used when generating samples with a pre-trained model.
207+
device (`str` or `torch.device`, optional):
208+
the device to which the timesteps are moved to.
209+
custom_timesteps (`List[int]`, optional):
210+
custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
211+
timestep spacing strategy of equal spacing between timesteps is used.
212+
201213
"""
214+
if custom_timesteps is not None:
215+
num_inference_steps = len(custom_timesteps)
216+
217+
for i in range(1, len(custom_timesteps)):
218+
if custom_timesteps[i] >= custom_timesteps[i - 1]:
219+
raise ValueError("`custom_timesteps` must be in descending order.")
202220

203221
if num_inference_steps > self.config.num_train_timesteps:
204222
raise ValueError(
@@ -209,13 +227,19 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
209227

210228
self.num_inference_steps = num_inference_steps
211229

212-
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
213-
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
230+
if custom_timesteps is None:
231+
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
232+
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
233+
self.custom_timesteps = False
234+
else:
235+
timesteps = np.array(custom_timesteps, dtype=np.int64)
236+
self.custom_timesteps = True
237+
214238
self.timesteps = torch.from_numpy(timesteps).to(device)
215239

216240
def _get_variance(self, t, predicted_variance=None, variance_type=None):
217-
num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
218-
prev_t = t - self.config.num_train_timesteps // num_inference_steps
241+
prev_t = self.previous_timestep(t)
242+
219243
alpha_prod_t = self.alphas_cumprod[t]
220244
alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
221245
current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
@@ -314,8 +338,8 @@ def step(
314338
315339
"""
316340
t = timestep
317-
num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
318-
prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
341+
342+
prev_t = self.previous_timestep(t)
319343

320344
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
321345
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
@@ -428,3 +452,18 @@ def get_velocity(
428452

429453
def __len__(self):
430454
return self.config.num_train_timesteps
455+
456+
def previous_timestep(self, timestep):
457+
if self.custom_timesteps:
458+
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
459+
if index == self.timesteps.shape[0] - 1:
460+
prev_t = torch.tensor(-1)
461+
else:
462+
prev_t = self.timesteps[index + 1]
463+
else:
464+
num_inference_steps = (
465+
self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
466+
)
467+
prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
468+
469+
return prev_t

tests/schedulers/test_scheduler_ddpm.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,3 +129,37 @@ def test_full_loop_with_v_prediction(self):
129129

130130
assert abs(result_sum.item() - 202.0296) < 1e-2
131131
assert abs(result_mean.item() - 0.2631) < 1e-3
132+
133+
def test_custom_timesteps(self):
134+
scheduler_class = self.scheduler_classes[0]
135+
scheduler_config = self.get_scheduler_config()
136+
scheduler = scheduler_class(**scheduler_config)
137+
138+
timesteps = [100, 87, 50, 1, 0]
139+
num_inference_steps = len(timesteps)
140+
141+
scheduler.set_timesteps(num_inference_steps=num_inference_steps, custom_timesteps=timesteps)
142+
143+
scheduler_timesteps = scheduler.timesteps
144+
145+
for i, timestep in enumerate(scheduler_timesteps):
146+
if i == num_inference_steps - 1:
147+
expected_prev_t = -1
148+
else:
149+
expected_prev_t = timesteps[i + 1]
150+
151+
prev_t = scheduler.previous_timestep(timestep)
152+
prev_t = prev_t.item()
153+
154+
self.assertEqual(prev_t, expected_prev_t)
155+
156+
def test_custom_timesteps_increasing_order(self):
157+
scheduler_class = self.scheduler_classes[0]
158+
scheduler_config = self.get_scheduler_config()
159+
scheduler = scheduler_class(**scheduler_config)
160+
161+
timesteps = [100, 87, 50, 51, 0]
162+
num_inference_steps = len(timesteps)
163+
164+
with self.assertRaises(ValueError, msg="`custom_timesteps` must be in descending order."):
165+
scheduler.set_timesteps(num_inference_steps=num_inference_steps, custom_timesteps=timesteps)

0 commit comments

Comments
 (0)