Skip to content

Commit 394243c

Browse files
finish pndm sampler
1 parent fe98574 commit 394243c

File tree

7 files changed

+78
-71
lines changed

7 files changed

+78
-71
lines changed

src/diffusers/pipelines/ddpm/pipeline_ddpm.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,8 @@ def __call__(self, batch_size=1, generator=None, torch_device=None, output_type=
4848
# 1. predict noise model_output
4949
model_output = self.unet(image, t)["sample"]
5050

51-
# 2. predict previous mean of image x_t-1
52-
pred_prev_image = self.scheduler.step(model_output, t, image)["prev_sample"]
53-
54-
# 3. set current image to prev_image: x_t -> x_t-1
55-
image = pred_prev_image
51+
# 2. compute previous image: x_t -> x_t-1
52+
image = self.scheduler.step(model_output, t, image)["prev_sample"]
5653

5754
image = (image / 2 + 0.5).clamp(0, 1)
5855
image = image.cpu().permute(0, 2, 3, 1).numpy()

src/diffusers/pipelines/pndm/pipeline_pndm.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,20 @@ def __call__(self, batch_size=1, generator=None, torch_device=None, num_inferenc
4444
image = image.to(torch_device)
4545

4646
self.scheduler.set_timesteps(num_inference_steps)
47-
for i, t in enumerate(tqdm(self.scheduler.prk_timesteps)):
47+
for t in tqdm(self.scheduler.timesteps):
4848
model_output = self.unet(image, t)["sample"]
4949

50-
image = self.scheduler.step_prk(model_output, i, image, num_inference_steps)["prev_sample"]
51-
52-
for i, t in enumerate(tqdm(self.scheduler.plms_timesteps)):
53-
model_output = self.unet(image, t)["sample"]
54-
55-
image = self.scheduler.step_plms(model_output, i, image, num_inference_steps)["prev_sample"]
50+
image = self.scheduler.step(model_output, t, image)["prev_sample"]
51+
52+
# for i, t in enumerate(tqdm(self.scheduler.prk_timesteps)):
53+
# model_output = self.unet(image, t)["sample"]
54+
#
55+
# image = self.scheduler.step_prk(model_output, t, image, i=i)["prev_sample"]
56+
#
57+
# for i, t in enumerate(tqdm(self.scheduler.plms_timesteps)):
58+
# model_output = self.unet(image, t)["sample"]
59+
#
60+
# image = self.scheduler.step_plms(model_output, t, image, i=i)["prev_sample"]
5661

5762
image = (image / 2 + 0.5).clamp(0, 1)
5863
image = image.cpu().permute(0, 2, 3, 1).numpy()

src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,21 +28,15 @@ def __call__(self, num_inference_steps=2000, generator=None, output_type="pil"):
2828
for i, t in tqdm(enumerate(self.scheduler.timesteps)):
2929
sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device)
3030

31+
# correction step
3132
for _ in range(self.scheduler.correct_steps):
32-
model_output = self.model(sample, sigma_t)
33-
34-
if isinstance(model_output, dict):
35-
model_output = model_output["sample"]
36-
33+
model_output = self.model(sample, sigma_t)["sample"]
3734
sample = self.scheduler.step_correct(model_output, sample)["prev_sample"]
3835

39-
with torch.no_grad():
40-
model_output = model(sample, sigma_t)
41-
42-
if isinstance(model_output, dict):
43-
model_output = model_output["sample"]
44-
36+
# prediction step
37+
model_output = model(sample, sigma_t)["sample"]
4538
output = self.scheduler.step_pred(model_output, t, sample)
39+
4640
sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
4741

4842
sample = sample.clamp(0, 1)

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,8 @@ def step(
106106
model_output: Union[torch.FloatTensor, np.ndarray],
107107
timestep: int,
108108
sample: Union[torch.FloatTensor, np.ndarray],
109-
eta,
110-
use_clipped_model_output=False,
109+
eta: float = 0.0,
110+
use_clipped_model_output: bool = False,
111111
generator=None,
112112
):
113113
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ def __init__(
5656
beta_end=0.02,
5757
beta_schedule="linear",
5858
trained_betas=None,
59-
timestep_values=None,
6059
variance_type="fixed_small",
6160
clip_sample=True,
6261
tensor_format="pt",

src/diffusers/schedulers/scheduling_pndm.py

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
1616

1717
import math
18-
import pdb
1918
from typing import Union
2019

2120
import numpy as np
@@ -79,78 +78,91 @@ def __init__(
7978

8079
# running values
8180
self.cur_model_output = 0
81+
self.counter = 0
8282
self.cur_sample = None
8383
self.ets = []
8484

8585
# setable values
8686
self.num_inference_steps = None
87-
self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
87+
self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
8888
self.prk_timesteps = None
8989
self.plms_timesteps = None
90+
self.timesteps = None
9091

9192
self.tensor_format = tensor_format
9293
self.set_format(tensor_format=tensor_format)
9394

9495
def set_timesteps(self, num_inference_steps):
9596
self.num_inference_steps = num_inference_steps
96-
self.timesteps = list(
97+
self._timesteps = list(
9798
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
9899
)
99100

100-
prk_time_steps = np.array(self.timesteps[-self.pndm_order :]).repeat(2) + np.tile(
101+
prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
101102
np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
102103
)
103-
self.prk_timesteps = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))
104-
self.plms_timesteps = list(reversed(self.timesteps[:-3]))
104+
self.prk_timesteps = list(reversed(prk_timesteps[:-1].repeat(2)[1:-1]))
105+
self.plms_timesteps = list(reversed(self._timesteps[:-3]))
106+
self.timesteps = self.prk_timesteps + self.plms_timesteps
105107

108+
self.counter = 0
106109
self.set_format(tensor_format=self.tensor_format)
107110

111+
def step(
112+
self,
113+
model_output: Union[torch.FloatTensor, np.ndarray],
114+
timestep: int,
115+
sample: Union[torch.FloatTensor, np.ndarray],
116+
):
117+
if self.counter < len(self.prk_timesteps):
118+
return self.step_prk(model_output=model_output, timestep=timestep, sample=sample)
119+
else:
120+
return self.step_plms(model_output=model_output, timestep=timestep, sample=sample)
121+
108122
def step_prk(
109123
self,
110124
model_output: Union[torch.FloatTensor, np.ndarray],
111125
timestep: int,
112126
sample: Union[torch.FloatTensor, np.ndarray],
113-
num_inference_steps,
114127
):
115128
"""
116129
Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
117130
solution to the differential equation.
118131
"""
119-
t = timestep
120-
prk_time_steps = self.prk_timesteps
121-
122-
t_orig = prk_time_steps[t // 4 * 4]
123-
t_orig_prev = prk_time_steps[min(t + 1, len(prk_time_steps) - 1)]
132+
diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2
133+
prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1])
134+
timestep = self.prk_timesteps[self.counter // 4 * 4]
124135

125-
if t % 4 == 0:
136+
if self.counter % 4 == 0:
126137
self.cur_model_output += 1 / 6 * model_output
127138
self.ets.append(model_output)
128139
self.cur_sample = sample
129-
elif (t - 1) % 4 == 0:
140+
elif (self.counter - 1) % 4 == 0:
130141
self.cur_model_output += 1 / 3 * model_output
131-
elif (t - 2) % 4 == 0:
142+
elif (self.counter - 2) % 4 == 0:
132143
self.cur_model_output += 1 / 3 * model_output
133-
elif (t - 3) % 4 == 0:
144+
elif (self.counter - 3) % 4 == 0:
134145
model_output = self.cur_model_output + 1 / 6 * model_output
135146
self.cur_model_output = 0
136147

137148
# cur_sample should not be `None`
138149
cur_sample = self.cur_sample if self.cur_sample is not None else sample
139150

140-
return {"prev_sample": self.get_prev_sample(cur_sample, t_orig, t_orig_prev, model_output)}
151+
prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
152+
self.counter += 1
153+
154+
return {"prev_sample": prev_sample}
141155

142156
def step_plms(
143157
self,
144158
model_output: Union[torch.FloatTensor, np.ndarray],
145159
timestep: int,
146160
sample: Union[torch.FloatTensor, np.ndarray],
147-
num_inference_steps,
148161
):
149162
"""
150163
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
151164
times to approximate the solution.
152165
"""
153-
t = timestep
154166
if len(self.ets) < 3:
155167
raise ValueError(
156168
f"{self.__class__} can only be run AFTER scheduler has been run "
@@ -159,17 +171,17 @@ def step_plms(
159171
"for more information."
160172
)
161173

162-
timesteps = self.plms_timesteps
163-
164-
t_orig = timesteps[t]
165-
t_orig_prev = timesteps[min(t + 1, len(timesteps) - 1)]
174+
prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0)
166175
self.ets.append(model_output)
167176

168177
model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
169178

170-
return {"prev_sample": self.get_prev_sample(sample, t_orig, t_orig_prev, model_output)}
179+
prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
180+
self.counter += 1
181+
182+
return {"prev_sample": prev_sample}
171183

172-
def get_prev_sample(self, sample, t_orig, t_orig_prev, model_output):
184+
def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
173185
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
174186
# this function computes x_(t−δ) using the formula of (9)
175187
# Note that x_t needs to be added to both sides of the equation
@@ -182,8 +194,8 @@ def get_prev_sample(self, sample, t_orig, t_orig_prev, model_output):
182194
# sample -> x_t
183195
# model_output -> e_θ(x_t, t)
184196
# prev_sample -> x_(t−δ)
185-
alpha_prod_t = self.alphas_cumprod[t_orig + 1]
186-
alpha_prod_t_prev = self.alphas_cumprod[t_orig_prev + 1]
197+
alpha_prod_t = self.alphas_cumprod[timestep + 1]
198+
alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1]
187199
beta_prod_t = 1 - alpha_prod_t
188200
beta_prod_t_prev = 1 - alpha_prod_t_prev
189201

tests/test_scheduler.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
import pdb
1615
import tempfile
1716
import unittest
1817

@@ -383,21 +382,22 @@ def get_scheduler_config(self, **kwargs):
383382

384383
def check_over_configs(self, time_step=0, **config):
385384
kwargs = dict(self.forward_default_kwargs)
385+
num_inference_steps = kwargs.pop("num_inference_steps", None)
386386
sample = self.dummy_sample
387387
residual = 0.1 * sample
388388
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
389389

390390
for scheduler_class in self.scheduler_classes:
391391
scheduler_config = self.get_scheduler_config(**config)
392392
scheduler = scheduler_class(**scheduler_config)
393-
scheduler.set_timesteps(kwargs["num_inference_steps"])
393+
scheduler.set_timesteps(num_inference_steps)
394394
# copy over dummy past residuals
395395
scheduler.ets = dummy_past_residuals[:]
396396

397397
with tempfile.TemporaryDirectory() as tmpdirname:
398398
scheduler.save_config(tmpdirname)
399399
new_scheduler = scheduler_class.from_config(tmpdirname)
400-
new_scheduler.set_timesteps(kwargs["num_inference_steps"])
400+
new_scheduler.set_timesteps(num_inference_steps)
401401
# copy over dummy past residuals
402402
new_scheduler.ets = dummy_past_residuals[:]
403403

@@ -416,15 +416,15 @@ def test_from_pretrained_save_pretrained(self):
416416

417417
def check_over_forward(self, time_step=0, **forward_kwargs):
418418
kwargs = dict(self.forward_default_kwargs)
419-
kwargs.update(forward_kwargs)
419+
num_inference_steps = kwargs.pop("num_inference_steps", None)
420420
sample = self.dummy_sample
421421
residual = 0.1 * sample
422422
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
423423

424424
for scheduler_class in self.scheduler_classes:
425425
scheduler_config = self.get_scheduler_config()
426426
scheduler = scheduler_class(**scheduler_config)
427-
scheduler.set_timesteps(kwargs["num_inference_steps"])
427+
scheduler.set_timesteps(num_inference_steps)
428428

429429
# copy over dummy past residuals
430430
scheduler.ets = dummy_past_residuals[:]
@@ -434,7 +434,7 @@ def check_over_forward(self, time_step=0, **forward_kwargs):
434434
new_scheduler = scheduler_class.from_config(tmpdirname)
435435
# copy over dummy past residuals
436436
new_scheduler.ets = dummy_past_residuals[:]
437-
new_scheduler.set_timesteps(kwargs["num_inference_steps"])
437+
new_scheduler.set_timesteps(num_inference_steps)
438438

439439
output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
440440
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
@@ -474,12 +474,12 @@ def test_pytorch_equal_numpy(self):
474474
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
475475
kwargs["num_inference_steps"] = num_inference_steps
476476

477-
output = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
478-
output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"]
477+
output = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"]
478+
output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
479479
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
480480

481-
output = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
482-
output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"]
481+
output = scheduler.step_plms(residual, 1, sample, **kwargs)["prev_sample"]
482+
output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
483483

484484
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
485485

@@ -503,14 +503,14 @@ def test_step_shape(self):
503503
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
504504
kwargs["num_inference_steps"] = num_inference_steps
505505

506-
output_0 = scheduler.step_prk(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"]
507-
output_1 = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
506+
output_0 = scheduler.step_prk(residual, 0, sample, **kwargs)["prev_sample"]
507+
output_1 = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"]
508508

509509
self.assertEqual(output_0.shape, sample.shape)
510510
self.assertEqual(output_0.shape, output_1.shape)
511511

512-
output_0 = scheduler.step_plms(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"]
513-
output_1 = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
512+
output_0 = scheduler.step_plms(residual, 0, sample, **kwargs)["prev_sample"]
513+
output_1 = scheduler.step_plms(residual, 1, sample, **kwargs)["prev_sample"]
514514

515515
self.assertEqual(output_0.shape, sample.shape)
516516
self.assertEqual(output_0.shape, output_1.shape)
@@ -541,7 +541,7 @@ def test_inference_plms_no_past_residuals(self):
541541
scheduler_config = self.get_scheduler_config()
542542
scheduler = scheduler_class(**scheduler_config)
543543

544-
scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"]
544+
scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample)["prev_sample"]
545545

546546
def test_full_loop_no_noise(self):
547547
scheduler_class = self.scheduler_classes[0]
@@ -555,11 +555,11 @@ def test_full_loop_no_noise(self):
555555

556556
for i, t in enumerate(scheduler.prk_timesteps):
557557
residual = model(sample, t)
558-
sample = scheduler.step_prk(residual, i, sample, num_inference_steps)["prev_sample"]
558+
sample = scheduler.step_prk(residual, i, sample)["prev_sample"]
559559

560560
for i, t in enumerate(scheduler.plms_timesteps):
561561
residual = model(sample, t)
562-
sample = scheduler.step_plms(residual, i, sample, num_inference_steps)["prev_sample"]
562+
sample = scheduler.step_plms(residual, i, sample)["prev_sample"]
563563

564564
result_sum = torch.sum(torch.abs(sample))
565565
result_mean = torch.mean(torch.abs(sample))
@@ -706,7 +706,7 @@ def test_full_loop_no_noise(self):
706706
model_output = model(sample, sigma_t)
707707

708708
output = scheduler.step_pred(model_output, t, sample, **kwargs)
709-
sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
709+
sample, _ = output["prev_sample"], output["prev_sample_mean"]
710710

711711
result_sum = torch.sum(torch.abs(sample))
712712
result_mean = torch.mean(torch.abs(sample))

0 commit comments

Comments
 (0)