Skip to content

Commit f4781a0

Browse files
update expected results of slow tests (#268)
update expected results of slow tests (#268) * update expected results of slow tests * relax sum and mean tests * Print shapes when reporting exception * formatting * fix sentence * relax test_stable_diffusion_fast_ddim for gpu fp16 * relax flaky tests on GPU * added comment on large tolerances * black * format * set scheduler seed * added generator * use np.isclose * set num_inference_steps to 50 * fix dep. warning * update expected_slice * preprocess if image * updated expected results * updated expected from CI * pass generator to VAE * undo change back to orig * use original * revert back the expected on cpu * revert back values for CPU * more undo * update result after using gen * update mean * set generator for mps * update expected on CI server * undo * use new seed every time * cpu manual seed * reduce num_inference_steps * style * use generator for randn Co-authored-by: Patrick von Platen <[email protected]>
1 parent 25a51b6 commit f4781a0

File tree

7 files changed

+75
-47
lines changed

7 files changed

+75
-47
lines changed

src/diffusers/models/vae.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,11 @@ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decode
557557
return DecoderOutput(sample=dec)
558558

559559
def forward(
560-
self, sample: torch.FloatTensor, sample_posterior: bool = False, return_dict: bool = True
560+
self,
561+
sample: torch.FloatTensor,
562+
sample_posterior: bool = False,
563+
return_dict: bool = True,
564+
generator: Optional[torch.Generator] = None,
561565
) -> Union[DecoderOutput, torch.FloatTensor]:
562566
r"""
563567
Args:
@@ -570,7 +574,7 @@ def forward(
570574
x = sample
571575
posterior = self.encode(x).latent_dist
572576
if sample_posterior:
573-
z = posterior.sample()
577+
z = posterior.sample(generator=generator)
574578
else:
575579
z = posterior.mode()
576580
dec = self.decode(z).sample

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def __call__(
178178

179179
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
180180

181-
if not isinstance(init_image, torch.FloatTensor):
181+
if isinstance(init_image, PIL.Image.Image):
182182
init_image = preprocess(init_image)
183183

184184
# encode the init image into latents and scale the latents

tests/test_models_unet.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,13 @@ def test_output_pretrained(self):
138138
model.eval()
139139
model.to(torch_device)
140140

141-
torch.manual_seed(0)
142-
if torch.cuda.is_available():
143-
torch.cuda.manual_seed_all(0)
144-
145-
noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
141+
noise = torch.randn(
142+
1,
143+
model.config.in_channels,
144+
model.config.sample_size,
145+
model.config.sample_size,
146+
generator=torch.manual_seed(0),
147+
)
146148
noise = noise.to(torch_device)
147149
time_step = torch.tensor([10] * noise.shape[0]).to(torch_device)
148150

@@ -154,7 +156,7 @@ def test_output_pretrained(self):
154156
expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800])
155157
# fmt: on
156158

157-
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
159+
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
158160

159161

160162
# TODO(Patrick) - Re-add this test after having cleaned up LDM

tests/test_models_vae.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -87,18 +87,24 @@ def test_output_pretrained(self):
8787
image = image.to(torch_device)
8888
with torch.no_grad():
8989
_ = model(image, sample_posterior=True).sample
90-
91-
torch.manual_seed(0)
92-
if torch.cuda.is_available():
93-
torch.cuda.manual_seed_all(0)
94-
95-
image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
90+
generator = torch.manual_seed(0)
91+
else:
92+
generator = torch.Generator(device=torch_device).manual_seed(0)
93+
94+
image = torch.randn(
95+
1,
96+
model.config.in_channels,
97+
model.config.sample_size,
98+
model.config.sample_size,
99+
generator=torch.manual_seed(0),
100+
)
96101
image = image.to(torch_device)
97102
with torch.no_grad():
98-
output = model(image, sample_posterior=True).sample
103+
output = model(image, sample_posterior=True, generator=generator).sample
99104

100105
output_slice = output[0, -1, -3:, -3:].flatten().cpu()
106+
101107
# fmt: off
102-
expected_output_slice = torch.tensor([-4.0078e-01, -3.8304e-04, -1.2681e-01, -1.1462e-01, 2.0095e-01, 1.0893e-01, -8.8248e-02, -3.0361e-01, -9.8646e-03])
108+
expected_output_slice = torch.tensor([-0.1352, 0.0878, 0.0419, -0.0818, -0.1069, 0.0688, -0.1458, -0.4446, -0.0026])
103109
# fmt: on
104110
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))

tests/test_models_vq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,4 +94,4 @@ def test_output_pretrained(self):
9494
# fmt: off
9595
expected_output_slice = torch.tensor([-0.0153, -0.4044, -0.1880, -0.5161, -0.2418, -0.4072, -0.1612, -0.0633, -0.0143])
9696
# fmt: on
97-
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
97+
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))

tests/test_pipelines.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,7 @@ def test_stable_diffusion_ddim(self):
330330

331331
assert image.shape == (1, 128, 128, 3)
332332
expected_slice = np.array([0.5112, 0.4692, 0.4715, 0.5206, 0.4894, 0.5114, 0.5096, 0.4932, 0.4755])
333+
333334
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
334335
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
335336

@@ -463,17 +464,18 @@ def test_score_sde_ve_pipeline(self):
463464
sde_ve.to(torch_device)
464465
sde_ve.set_progress_bar_config(disable=None)
465466

466-
torch.manual_seed(0)
467-
image = sde_ve(num_inference_steps=2, output_type="numpy").images
467+
generator = torch.manual_seed(0)
468+
image = sde_ve(num_inference_steps=2, output_type="numpy", generator=generator).images
468469

469-
torch.manual_seed(0)
470-
image_from_tuple = sde_ve(num_inference_steps=2, output_type="numpy", return_dict=False)[0]
470+
generator = torch.manual_seed(0)
471+
image_from_tuple = sde_ve(num_inference_steps=2, output_type="numpy", generator=generator, return_dict=False)[
472+
0
473+
]
471474

472475
image_slice = image[0, -3:, -3:, -1]
473476
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
474477

475478
assert image.shape == (1, 32, 32, 3)
476-
477479
expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
478480
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
479481
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
@@ -647,7 +649,7 @@ def test_stable_diffusion_inpaint(self):
647649
bert = self.dummy_text_encoder
648650
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
649651

650-
image = self.dummy_image.to(device).permute(0, 2, 3, 1)[0]
652+
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
651653
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
652654
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
653655

@@ -729,8 +731,8 @@ def test_from_pretrained_save_pretrained(self):
729731
new_ddpm.to(torch_device)
730732

731733
generator = torch.manual_seed(0)
732-
733734
image = ddpm(generator=generator, output_type="numpy").images
735+
734736
generator = generator.manual_seed(0)
735737
new_image = new_ddpm(generator=generator, output_type="numpy").images
736738

@@ -750,8 +752,8 @@ def test_from_pretrained_hub(self):
750752
ddpm_from_hub.set_progress_bar_config(disable=None)
751753

752754
generator = torch.manual_seed(0)
753-
754755
image = ddpm(generator=generator, output_type="numpy").images
756+
755757
generator = generator.manual_seed(0)
756758
new_image = ddpm_from_hub(generator=generator, output_type="numpy").images
757759

@@ -774,8 +776,8 @@ def test_from_pretrained_hub_pass_model(self):
774776
ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)
775777

776778
generator = torch.manual_seed(0)
777-
778779
image = ddpm_from_hub_custom_model(generator=generator, output_type="numpy").images
780+
779781
generator = generator.manual_seed(0)
780782
new_image = ddpm_from_hub(generator=generator, output_type="numpy").images
781783

@@ -981,14 +983,14 @@ def test_score_sde_ve_pipeline(self):
981983
sde_ve.to(torch_device)
982984
sde_ve.set_progress_bar_config(disable=None)
983985

984-
torch.manual_seed(0)
985-
image = sde_ve(num_inference_steps=300, output_type="numpy").images
986+
generator = torch.manual_seed(0)
987+
image = sde_ve(num_inference_steps=10, output_type="numpy", generator=generator).images
986988

987989
image_slice = image[0, -3:, -3:, -1]
988990

989991
assert image.shape == (1, 256, 256, 3)
990992

991-
expected_slice = np.array([0.64363, 0.5868, 0.3031, 0.2284, 0.7409, 0.3216, 0.25643, 0.6557, 0.2633])
993+
expected_slice = np.array([0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0])
992994
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
993995

994996
@slow

tests/test_scheduler.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -318,13 +318,14 @@ def test_full_loop_no_noise(self):
318318

319319
model = self.dummy_model()
320320
sample = self.dummy_sample_deter
321+
generator = torch.manual_seed(0)
321322

322323
for t in reversed(range(num_trained_timesteps)):
323324
# 1. predict noise residual
324325
residual = model(sample, t)
325326

326327
# 2. predict previous mean of sample x_t-1
327-
pred_prev_sample = scheduler.step(residual, t, sample).prev_sample
328+
pred_prev_sample = scheduler.step(residual, t, sample, generator=generator).prev_sample
328329

329330
# if t > 0:
330331
# noise = self.dummy_sample_deter
@@ -336,7 +337,7 @@ def test_full_loop_no_noise(self):
336337
result_sum = torch.sum(torch.abs(sample))
337338
result_mean = torch.mean(torch.abs(sample))
338339

339-
assert abs(result_sum.item() - 259.0883) < 1e-2
340+
assert abs(result_sum.item() - 258.9070) < 1e-2
340341
assert abs(result_mean.item() - 0.3374) < 1e-3
341342

342343

@@ -657,7 +658,7 @@ def test_full_loop_no_noise(self):
657658
class ScoreSdeVeSchedulerTest(unittest.TestCase):
658659
# TODO adapt with class SchedulerCommonTest (scheduler needs Numpy Integration)
659660
scheduler_classes = (ScoreSdeVeScheduler,)
660-
forward_default_kwargs = (("seed", 0),)
661+
forward_default_kwargs = ()
661662

662663
@property
663664
def dummy_sample(self):
@@ -718,13 +719,19 @@ def check_over_configs(self, time_step=0, **config):
718719
scheduler.save_config(tmpdirname)
719720
new_scheduler = scheduler_class.from_config(tmpdirname)
720721

721-
output = scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample
722-
new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample
722+
output = scheduler.step_pred(
723+
residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
724+
).prev_sample
725+
new_output = new_scheduler.step_pred(
726+
residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
727+
).prev_sample
723728

724729
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
725730

726-
output = scheduler.step_correct(residual, sample, **kwargs).prev_sample
727-
new_output = new_scheduler.step_correct(residual, sample, **kwargs).prev_sample
731+
output = scheduler.step_correct(residual, sample, generator=torch.manual_seed(0), **kwargs).prev_sample
732+
new_output = new_scheduler.step_correct(
733+
residual, sample, generator=torch.manual_seed(0), **kwargs
734+
).prev_sample
728735

729736
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical"
730737

@@ -743,13 +750,19 @@ def check_over_forward(self, time_step=0, **forward_kwargs):
743750
scheduler.save_config(tmpdirname)
744751
new_scheduler = scheduler_class.from_config(tmpdirname)
745752

746-
output = scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample
747-
new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample
753+
output = scheduler.step_pred(
754+
residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
755+
).prev_sample
756+
new_output = new_scheduler.step_pred(
757+
residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
758+
).prev_sample
748759

749760
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
750761

751-
output = scheduler.step_correct(residual, sample, **kwargs).prev_sample
752-
new_output = new_scheduler.step_correct(residual, sample, **kwargs).prev_sample
762+
output = scheduler.step_correct(residual, sample, generator=torch.manual_seed(0), **kwargs).prev_sample
763+
new_output = new_scheduler.step_correct(
764+
residual, sample, generator=torch.manual_seed(0), **kwargs
765+
).prev_sample
753766

754767
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical"
755768

@@ -779,26 +792,27 @@ def test_full_loop_no_noise(self):
779792

780793
scheduler.set_sigmas(num_inference_steps)
781794
scheduler.set_timesteps(num_inference_steps)
795+
generator = torch.manual_seed(0)
782796

783797
for i, t in enumerate(scheduler.timesteps):
784798
sigma_t = scheduler.sigmas[i]
785799

786800
for _ in range(scheduler.correct_steps):
787801
with torch.no_grad():
788802
model_output = model(sample, sigma_t)
789-
sample = scheduler.step_correct(model_output, sample, **kwargs).prev_sample
803+
sample = scheduler.step_correct(model_output, sample, generator=generator, **kwargs).prev_sample
790804

791805
with torch.no_grad():
792806
model_output = model(sample, sigma_t)
793807

794-
output = scheduler.step_pred(model_output, t, sample, **kwargs)
808+
output = scheduler.step_pred(model_output, t, sample, generator=generator, **kwargs)
795809
sample, _ = output.prev_sample, output.prev_sample_mean
796810

797811
result_sum = torch.sum(torch.abs(sample))
798812
result_mean = torch.mean(torch.abs(sample))
799813

800-
assert abs(result_sum.item() - 14379591680.0) < 1e-2
801-
assert abs(result_mean.item() - 18723426.0) < 1e-3
814+
assert np.isclose(result_sum.item(), 14372758528.0)
815+
assert np.isclose(result_mean.item(), 18714530.0)
802816

803817
def test_step_shape(self):
804818
kwargs = dict(self.forward_default_kwargs)
@@ -817,8 +831,8 @@ def test_step_shape(self):
817831
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
818832
kwargs["num_inference_steps"] = num_inference_steps
819833

820-
output_0 = scheduler.step_pred(residual, 0, sample, **kwargs).prev_sample
821-
output_1 = scheduler.step_pred(residual, 1, sample, **kwargs).prev_sample
834+
output_0 = scheduler.step_pred(residual, 0, sample, generator=torch.manual_seed(0), **kwargs).prev_sample
835+
output_1 = scheduler.step_pred(residual, 1, sample, generator=torch.manual_seed(0), **kwargs).prev_sample
822836

823837
self.assertEqual(output_0.shape, sample.shape)
824838
self.assertEqual(output_0.shape, output_1.shape)

0 commit comments

Comments
 (0)