Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pr_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ jobs:
- name: Run fast PyTorch tests on M1 (MPS)
shell: arch -arch arm64 bash {0}
run: |
${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps tests/
${CONDA_RUN} python -m pytest -n 0 -s -v --make-reports=tests_torch_mps tests/

- name: Failure short reports
if: ${{ failure() }}
Expand Down
14 changes: 8 additions & 6 deletions src/diffusers/pipelines/ddim/pipeline_ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __call__(
if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
message = (
f"The `generator` device is `{generator.device}` and does not match the pipeline "
f"device `{self.device}`, so the `generator` will be set to `None`. "
f"device `{self.device}`, so the `generator` will be ignored. "
f'Please use `generator=torch.Generator(device="{self.device}")` instead.'
)
deprecate(
Expand All @@ -89,11 +89,13 @@ def __call__(
generator = None

# Sample gaussian noise to begin loop
image = torch.randn(
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
generator=generator,
device=self.device,
)
image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
if self.device.type == "mps":
# randn does not work reproducibly on mps
image = torch.randn(image_shape, generator=generator)
image = image.to(self.device)
else:
image = torch.randn(image_shape, generator=generator, device=self.device)

# set step values
self.scheduler.set_timesteps(num_inference_steps)
Expand Down
14 changes: 8 additions & 6 deletions src/diffusers/pipelines/ddpm/pipeline_ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __call__(
if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
message = (
f"The `generator` device is `{generator.device}` and does not match the pipeline "
f"device `{self.device}`, so the `generator` will be set to `None`. "
f"device `{self.device}`, so the `generator` will be ignored. "
f'Please use `torch.Generator(device="{self.device}")` instead.'
)
deprecate(
Expand All @@ -94,11 +94,13 @@ def __call__(
generator = None

# Sample gaussian noise to begin loop
image = torch.randn(
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
generator=generator,
device=self.device,
)
image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
if self.device.type == "mps":
# randn does not work reproducibly on mps
image = torch.randn(image_shape, generator=generator)
image = image.to(self.device)
else:
image = torch.randn(image_shape, generator=generator, device=self.device)

# set step values
self.scheduler.set_timesteps(num_inference_steps)
Expand Down
8 changes: 6 additions & 2 deletions tests/pipelines/ddpm/test_ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,14 @@ def test_inference_predict_epsilon(self):
if torch_device == "mps":
_ = ddpm(num_inference_steps=1)

generator = torch.Generator(device=torch_device).manual_seed(0)
if torch_device == "mps":
# device type MPS is not supported for torch.Generator() api.
generator = torch.manual_seed(0)
else:
generator = torch.Generator(device=torch_device).manual_seed(0)
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images

generator = torch.Generator(device=torch_device).manual_seed(0)
generator = generator.manual_seed(0)
image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]

image_slice = image[0, -3:, -3:, -1]
Expand Down
26 changes: 21 additions & 5 deletions tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1281,7 +1281,11 @@ def test_full_loop_no_noise(self):

scheduler.set_timesteps(self.num_inference_steps)

generator = torch.Generator(torch_device).manual_seed(0)
if torch_device == "mps":
# device type MPS is not supported for torch.Generator() api.
generator = torch.manual_seed(0)
else:
generator = torch.Generator(device=torch_device).manual_seed(0)

model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
Expand All @@ -1308,7 +1312,11 @@ def test_full_loop_device(self):

scheduler.set_timesteps(self.num_inference_steps, device=torch_device)

generator = torch.Generator(torch_device).manual_seed(0)
if torch_device == "mps":
# device type MPS is not supported for torch.Generator() api.
generator = torch.manual_seed(0)
else:
generator = torch.Generator(device=torch_device).manual_seed(0)

model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
Expand Down Expand Up @@ -1364,7 +1372,11 @@ def test_full_loop_no_noise(self):

scheduler.set_timesteps(self.num_inference_steps)

generator = torch.Generator(device=torch_device).manual_seed(0)
if torch_device == "mps":
# device type MPS is not supported for torch.Generator() api.
generator = torch.manual_seed(0)
else:
generator = torch.Generator(device=torch_device).manual_seed(0)

model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
Expand All @@ -1381,7 +1393,7 @@ def test_full_loop_no_noise(self):
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))

if str(torch_device).startswith("cpu"):
if torch_device in ["cpu", "mps"]:
assert abs(result_sum.item() - 152.3192) < 1e-2
assert abs(result_mean.item() - 0.1983) < 1e-3
else:
Expand All @@ -1396,7 +1408,11 @@ def test_full_loop_device(self):

scheduler.set_timesteps(self.num_inference_steps, device=torch_device)

generator = torch.Generator(device=torch_device).manual_seed(0)
if torch_device == "mps":
# device type MPS is not supported for torch.Generator() api.
generator = torch.manual_seed(0)
else:
generator = torch.Generator(device=torch_device).manual_seed(0)

model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
Expand Down