
Commit f300d05

fix reproducible initial noise
1 parent 86d4c5a commit f300d05

3 files changed (+29 −23 lines)

src/diffusers/pipelines/ddim/pipeline_ddim.py

Lines changed: 7 additions & 5 deletions
@@ -89,11 +89,13 @@ def __call__(
             generator = None
 
         # Sample gaussian noise to begin loop
-        image = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
-            generator=generator,
-            device=self.device,
-        )
+        image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        if self.device.type == "mps":
+            # randn does not work reproducibly on mps
+            image = torch.randn(image_shape, generator=generator)
+            image = image.to(self.device)
+        else:
+            image = torch.randn(image_shape, generator=generator, device=self.device)
 
         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
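
For context: torch.randn with a seeded generator has not been reproducible on the mps backend, so the commit draws the initial noise on CPU and moves it to the device afterwards. A minimal standalone sketch of that same pattern (the helper name and shape here are illustrative, not part of the commit):

import torch

def reproducible_randn(shape, seed, device):
    # Draw the noise on CPU so a given seed always yields the same
    # tensor, then move it to the target device, mirroring the mps
    # branch in the pipeline above.
    generator = torch.manual_seed(seed)  # seeds and returns the default CPU generator
    noise = torch.randn(shape, generator=generator)
    return noise.to(device)

a = reproducible_randn((1, 3, 32, 32), seed=0, device="cpu")
b = reproducible_randn((1, 3, 32, 32), seed=0, device="cpu")
assert torch.equal(a, b)  # identical across calls, regardless of target device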

src/diffusers/pipelines/ddpm/pipeline_ddpm.py

Lines changed: 7 additions & 5 deletions
@@ -94,11 +94,13 @@ def __call__(
             generator = None
 
         # Sample gaussian noise to begin loop
-        image = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
-            generator=generator,
-            device=self.device,
-        )
+        image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        if self.device.type == "mps":
+            # randn does not work reproducibly on mps
+            image = torch.randn(image_shape, generator=generator)
+            image = image.to(self.device)
+        else:
+            image = torch.randn(image_shape, generator=generator, device=self.device)
 
         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
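
The DDPM pipeline receives the identical change. With the initial noise drawn on CPU for mps, a seeded run should reproduce its images across invocations; a usage sketch, assuming the google/ddpm-cifar10-32 checkpoint exercised by the tests below:

import torch
from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")
pipe.to("mps")  # or "cpu" / "cuda"

# torch.manual_seed returns the default CPU generator, which now works
# on every backend because the pipeline samples its noise on CPU for mps.
generator = torch.manual_seed(0)
image = pipe(generator=generator, output_type="numpy").images[0]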

tests/pipelines/ddpm/test_ddpm.py

Lines changed: 15 additions & 13 deletions
@@ -20,7 +20,7 @@
 
 from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
 from diffusers.utils import deprecate
-from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device
+from diffusers.utils.testing_utils import require_torch, slow, torch_device
 
 from ...test_pipelines_common import PipelineTesterMixin
 
@@ -44,18 +44,21 @@ def dummy_uncond_unet(self):
         return model
 
     def test_inference(self):
-        device = "cpu"
         unet = self.dummy_uncond_unet
         scheduler = DDPMScheduler()
 
         ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
-        ddpm.to(device)
+        ddpm.to(torch_device)
         ddpm.set_progress_bar_config(disable=None)
 
-        generator = torch.Generator(device=device).manual_seed(0)
+        # Warmup pass when using mps (see #372)
+        if torch_device == "mps":
+            _ = ddpm(num_inference_steps=1)
+
+        generator = torch.manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
 
-        generator = torch.Generator(device=device).manual_seed(0)
+        generator = torch.manual_seed(0)
         image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]
 
         image_slice = image[0, -3:, -3:, -1]
@@ -65,8 +68,9 @@ def test_inference(self):
         expected_slice = np.array(
             [5.589e-01, 7.089e-01, 2.632e-01, 6.841e-01, 1.000e-04, 9.999e-01, 1.973e-01, 1.000e-04, 8.010e-02]
         )
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+        tolerance = 1e-2 if torch_device != "mps" else 3e-2
+        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
 
     def test_inference_predict_epsilon(self):
         deprecate("remove this test", "0.10.0", "remove")
@@ -80,16 +84,14 @@ def test_inference_predict_epsilon(self):
         # Warmup pass when using mps (see #372)
         if torch_device == "mps":
             _ = ddpm(num_inference_steps=1)
+
         if torch_device == "mps":
             # device type MPS is not supported for torch.Generator() api.
             generator = torch.manual_seed(0)
         else:
             generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
 
-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            _ = ddpm(num_inference_steps=1)
         generator = generator.manual_seed(0)
         image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]
 
@@ -102,7 +104,7 @@ def test_inference_predict_epsilon(self):
 
 
 @slow
-@require_torch_gpu
+@require_torch
 class DDPMPipelineIntegrationTests(unittest.TestCase):
     def test_inference_cifar10(self):
         model_id = "google/ddpm-cifar10-32"
@@ -114,11 +116,11 @@ def test_inference_cifar10(self):
         ddpm.to(torch_device)
         ddpm.set_progress_bar_config(disable=None)
 
-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)
         image = ddpm(generator=generator, output_type="numpy").images
 
         image_slice = image[0, -3:, -3:, -1]
 
         assert image.shape == (1, 32, 32, 3)
-        expected_slice = np.array([0.4454, 0.2025, 0.0315, 0.3023, 0.2575, 0.1031, 0.0953, 0.1604, 0.2020])
+        expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
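
One subtlety in the updated tests: torch.manual_seed(0) returns the default CPU torch.Generator, which is why test_inference_predict_epsilon can later re-seed it via generator.manual_seed(0) instead of constructing a new generator. A small standalone illustration (not part of the commit):

import torch

# torch.manual_seed seeds and returns the default CPU generator;
# Generator.manual_seed re-seeds in place and also returns the generator.
g1 = torch.manual_seed(0)
g2 = g1.manual_seed(0)
assert g1 is g2

# Re-seeding reproduces the same noise sequence.
a = torch.randn(4, generator=g1.manual_seed(0))
b = torch.randn(4, generator=g1.manual_seed(0))
assert torch.equal(a, b)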
