
Commit bb12328

ddpm custom timesteps

add custom timesteps test
add custom timesteps descending order check
docs: timesteps -> custom_timesteps
can only pass one of num_inference_steps and timesteps

1 parent eb2ef31 commit bb12328

File tree

2 files changed: +119 −16 lines

src/diffusers/schedulers/scheduling_ddpm.py

63 additions & 16 deletions
@@ -162,6 +162,7 @@ def __init__(
         self.init_noise_sigma = 1.0
 
         # setable values
+        self.custom_timesteps = False
         self.num_inference_steps = None
         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
 
@@ -191,31 +192,62 @@ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] =
         """
         return sample
 
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(
+        self,
+        num_inference_steps: Optional[int] = None,
+        device: Union[str, torch.device] = None,
+        timesteps: Optional[List[int]] = None,
+    ):
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
 
         Args:
-            num_inference_steps (`int`):
-                the number of diffusion steps used when generating samples with a pre-trained model.
+            num_inference_steps (`Optional[int]`):
+                the number of diffusion steps used when generating samples with a pre-trained model. If passed, then
+                `timesteps` must be `None`.
+            device (`str` or `torch.device`, optional):
+                the device to which the timesteps are moved.
+            timesteps (`List[int]`, optional):
+                custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+                timestep spacing strategy of equal spacing between timesteps is used. If passed, `num_inference_steps`
+                must be `None`.
+
         """
+        if num_inference_steps is not None and timesteps is not None:
+            raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
+
+        if timesteps is not None:
+            for i in range(1, len(timesteps)):
+                if timesteps[i] >= timesteps[i - 1]:
+                    raise ValueError("`custom_timesteps` must be in descending order.")
+
+            if timesteps[0] >= self.config.num_train_timesteps:
+                raise ValueError(
+                    f"`timesteps` must start before `self.config.train_timesteps`:"
+                    f" {self.config.num_train_timesteps}."
+                )
+
+            timesteps = np.array(timesteps, dtype=np.int64)
+            self.custom_timesteps = True
+        else:
+            if num_inference_steps > self.config.num_train_timesteps:
+                raise ValueError(
+                    f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+                    f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+                    f" maximal {self.config.num_train_timesteps} timesteps."
+                )
 
-        if num_inference_steps > self.config.num_train_timesteps:
-            raise ValueError(
-                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
-                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
-                f" maximal {self.config.num_train_timesteps} timesteps."
-            )
+            self.num_inference_steps = num_inference_steps
 
-        self.num_inference_steps = num_inference_steps
+            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+            self.custom_timesteps = False
 
-        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
-        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
         self.timesteps = torch.from_numpy(timesteps).to(device)
 
     def _get_variance(self, t, predicted_variance=None, variance_type=None):
-        num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-        prev_t = t - self.config.num_train_timesteps // num_inference_steps
+        prev_t = self.previous_timestep(t)
+
         alpha_prod_t = self.alphas_cumprod[t]
         alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
         current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
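
For orientation, this is how the extended `set_timesteps` API is meant to be called. A minimal sketch, assuming a default `DDPMScheduler` (where `num_train_timesteps` defaults to 1000); the two arguments are mutually exclusive:

from diffusers import DDPMScheduler

scheduler = DDPMScheduler()  # num_train_timesteps defaults to 1000

# Default strategy: equally spaced steps derived from num_inference_steps.
scheduler.set_timesteps(num_inference_steps=10)
print(scheduler.timesteps)  # tensor([900, 800, 700, ..., 100, 0])

# New strategy: arbitrary, strictly descending custom timesteps.
scheduler.set_timesteps(timesteps=[999, 750, 300, 50, 0])
print(scheduler.timesteps)  # tensor([999, 750, 300, 50, 0])

# Passing both raises:
# ValueError: Can only pass one of `num_inference_steps` or `custom_timesteps`.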
@@ -314,8 +346,8 @@ def step(
 
         """
         t = timestep
-        num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-        prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
+
+        prev_t = self.previous_timestep(t)
 
         if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
             model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
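
With `step()` now delegating to `previous_timestep`, a denoising loop over custom timesteps reads the same as with the default spacing. A sketch under the same assumptions; the random `model_output` and the 1×3×8×8 sample shape are stand-ins for a real UNet prediction:

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler()
scheduler.set_timesteps(timesteps=[100, 87, 50, 1, 0])

sample = torch.randn(1, 3, 8, 8)  # placeholder sample
for t in scheduler.timesteps:
    model_output = torch.randn(1, 3, 8, 8)  # stand-in for a UNet prediction
    sample = scheduler.step(model_output, t, sample).prev_sample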
@@ -428,3 +460,18 @@ def get_velocity(
 
     def __len__(self):
         return self.config.num_train_timesteps
+
+    def previous_timestep(self, timestep):
+        if self.custom_timesteps:
+            index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
+            if index == self.timesteps.shape[0] - 1:
+                prev_t = torch.tensor(-1)
+            else:
+                prev_t = self.timesteps[index + 1]
+        else:
+            num_inference_steps = (
+                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
+            )
+            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
+
+        return prev_t
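
`previous_timestep` walks `self.timesteps` when custom spacing is active, and otherwise falls back to the fixed `num_train_timesteps // num_inference_steps` stride; the last entry maps to -1, closing the chain. A quick illustration, same assumptions as above:

from diffusers import DDPMScheduler

scheduler = DDPMScheduler()
scheduler.set_timesteps(timesteps=[100, 87, 50, 1, 0])

for t in scheduler.timesteps:
    print(int(t), "->", int(scheduler.previous_timestep(t)))
# 100 -> 87, 87 -> 50, 50 -> 1, 1 -> 0, 0 -> -1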

tests/schedulers/test_scheduler_ddpm.py

56 additions & 0 deletions
@@ -129,3 +129,59 @@ def test_full_loop_with_v_prediction(self):
 
         assert abs(result_sum.item() - 202.0296) < 1e-2
         assert abs(result_mean.item() - 0.2631) < 1e-3
+
+    def test_custom_timesteps(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        timesteps = [100, 87, 50, 1, 0]
+
+        scheduler.set_timesteps(timesteps=timesteps)
+
+        scheduler_timesteps = scheduler.timesteps
+
+        for i, timestep in enumerate(scheduler_timesteps):
+            if i == len(timesteps) - 1:
+                expected_prev_t = -1
+            else:
+                expected_prev_t = timesteps[i + 1]
+
+            prev_t = scheduler.previous_timestep(timestep)
+            prev_t = prev_t.item()
+
+            self.assertEqual(prev_t, expected_prev_t)
+
+    def test_custom_timesteps_increasing_order(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        timesteps = [100, 87, 50, 51, 0]
+
+        with self.assertRaises(ValueError, msg="`custom_timesteps` must be in descending order."):
+            scheduler.set_timesteps(timesteps=timesteps)
+
+    def test_custom_timesteps_passing_both_num_inference_steps_and_timesteps(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        timesteps = [100, 87, 50, 1, 0]
+        num_inference_steps = len(timesteps)
+
+        with self.assertRaises(ValueError, msg="Can only pass one of `num_inference_steps` or `custom_timesteps`."):
+            scheduler.set_timesteps(num_inference_steps=num_inference_steps, timesteps=timesteps)
+
+    def test_custom_timesteps_too_large(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        timesteps = [scheduler.config.num_train_timesteps]
+
+        with self.assertRaises(
+            ValueError,
+            msg=f"`timesteps` must start before `self.config.train_timesteps`: {scheduler.config.num_train_timesteps}",
+        ):
+            scheduler.set_timesteps(timesteps=timesteps)
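
Outside the test harness, the same validations can be triggered directly. A minimal sketch mirroring the tests above:

from diffusers import DDPMScheduler

scheduler = DDPMScheduler()

# Ascending entries trip the descending-order check.
try:
    scheduler.set_timesteps(timesteps=[100, 87, 50, 51, 0])
except ValueError as e:
    print(e)  # `custom_timesteps` must be in descending order.

# The first entry must lie strictly below num_train_timesteps.
try:
    scheduler.set_timesteps(timesteps=[scheduler.config.num_train_timesteps])
except ValueError as e:
    print(e)  # `timesteps` must start before `self.config.train_timesteps`: 1000.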
