Skip to content

Commit a47979d

Browse files
authored
[Tests] Fix mps+generator fast tests (huggingface#1230)
* [Tests] Fix mps+generator fast tests * mps for Euler * retry * warmup issue again? * fix reproducible initial noise * Revert "fix reproducible initial noise" This reverts commit f300d05. * fix reproducible initial noise * fix device
1 parent 57e24ba commit a47979d

File tree

2 files changed

+16
-12
lines changed

2 files changed

+16
-12
lines changed

pipelines/ddim/pipeline_ddim.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def __call__(
7878
if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
7979
message = (
8080
f"The `generator` device is `{generator.device}` and does not match the pipeline "
81-
f"device `{self.device}`, so the `generator` will be set to `None`. "
81+
f"device `{self.device}`, so the `generator` will be ignored. "
8282
f'Please use `generator=torch.Generator(device="{self.device}")` instead.'
8383
)
8484
deprecate(
@@ -89,11 +89,13 @@ def __call__(
8989
generator = None
9090

9191
# Sample gaussian noise to begin loop
92-
image = torch.randn(
93-
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
94-
generator=generator,
95-
device=self.device,
96-
)
92+
image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
93+
if self.device.type == "mps":
94+
# randn does not work reproducibly on mps
95+
image = torch.randn(image_shape, generator=generator)
96+
image = image.to(self.device)
97+
else:
98+
image = torch.randn(image_shape, generator=generator, device=self.device)
9799

98100
# set step values
99101
self.scheduler.set_timesteps(num_inference_steps)

pipelines/ddpm/pipeline_ddpm.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def __call__(
8383
if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
8484
message = (
8585
f"The `generator` device is `{generator.device}` and does not match the pipeline "
86-
f"device `{self.device}`, so the `generator` will be set to `None`. "
86+
f"device `{self.device}`, so the `generator` will be ignored. "
8787
f'Please use `torch.Generator(device="{self.device}")` instead.'
8888
)
8989
deprecate(
@@ -94,11 +94,13 @@ def __call__(
9494
generator = None
9595

9696
# Sample gaussian noise to begin loop
97-
image = torch.randn(
98-
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
99-
generator=generator,
100-
device=self.device,
101-
)
97+
image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
98+
if self.device.type == "mps":
99+
# randn does not work reproducibly on mps
100+
image = torch.randn(image_shape, generator=generator)
101+
image = image.to(self.device)
102+
else:
103+
image = torch.randn(image_shape, generator=generator, device=self.device)
102104

103105
# set step values
104106
self.scheduler.set_timesteps(num_inference_steps)

0 commit comments

Comments
 (0)