From 4aa7079664b2467606f1c04270967f68d2b78ee2 Mon Sep 17 00:00:00 2001
From: anton-l
Date: Wed, 9 Nov 2022 23:12:17 +0100
Subject: [PATCH 1/8] [Tests] Fix mps+generator fast tests

---
 tests/pipelines/ddpm/test_ddpm.py |  8 ++++++--
 tests/test_scheduler.py           | 24 ++++++++++++++++++++----
 2 files changed, 26 insertions(+), 6 deletions(-)

diff --git a/tests/pipelines/ddpm/test_ddpm.py b/tests/pipelines/ddpm/test_ddpm.py
index e16e0d6e8cbd..14bc09469773 100644
--- a/tests/pipelines/ddpm/test_ddpm.py
+++ b/tests/pipelines/ddpm/test_ddpm.py
@@ -81,10 +81,14 @@ def test_inference_predict_epsilon(self):
         if torch_device == "mps":
             _ = ddpm(num_inference_steps=1)
 
-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
 
-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = generator.manual_seed(0)
         image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]
 
         image_slice = image[0, -3:, -3:, -1]
diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py
index ab5217151125..2e5b27bf2960 100755
--- a/tests/test_scheduler.py
+++ b/tests/test_scheduler.py
@@ -1281,7 +1281,11 @@ def test_full_loop_no_noise(self):
 
         scheduler.set_timesteps(self.num_inference_steps)
 
-        generator = torch.Generator(torch_device).manual_seed(0)
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
 
         model = self.dummy_model()
         sample = self.dummy_sample_deter * scheduler.init_noise_sigma
@@ -1308,7 +1312,11 @@ def test_full_loop_device(self):
 
         scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
 
-        generator = torch.Generator(torch_device).manual_seed(0)
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
 
         model = self.dummy_model()
         sample = self.dummy_sample_deter * scheduler.init_noise_sigma
@@ -1364,7 +1372,11 @@ def test_full_loop_no_noise(self):
 
         scheduler.set_timesteps(self.num_inference_steps)
 
-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
 
         model = self.dummy_model()
         sample = self.dummy_sample_deter * scheduler.init_noise_sigma
@@ -1396,7 +1408,11 @@ def test_full_loop_device(self):
 
         scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
 
-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
 
         model = self.dummy_model()
         sample = self.dummy_sample_deter * scheduler.init_noise_sigma

From d36c5f89be21ff1edbac5c5e2751b94d4eb324b8 Mon Sep 17 00:00:00 2001
From: anton-l
Date: Wed, 9 Nov 2022 23:21:04 +0100
Subject: [PATCH 2/8] mps for Euler

---
 tests/test_scheduler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py
index 2e5b27bf2960..a9770f0a54a8 100755
--- a/tests/test_scheduler.py
+++ b/tests/test_scheduler.py
@@ -1393,7 +1393,7 @@ def test_full_loop_no_noise(self):
         result_sum = torch.sum(torch.abs(sample))
         result_mean = torch.mean(torch.abs(sample))
 
-        if str(torch_device).startswith("cpu"):
+        if torch_device in ["cpu", "mps"]:
             assert abs(result_sum.item() - 152.3192) < 1e-2
             assert abs(result_mean.item() - 0.1983) < 1e-3
         else:

From 85ec36bac74813a8238ce4b905d9d08c51ceb3fe Mon Sep 17 00:00:00 2001
From: anton-l
Date: Wed, 9 Nov 2022 23:36:56 +0100
Subject: [PATCH 3/8] retry

---
 tests/pipelines/ddpm/test_ddpm.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/tests/pipelines/ddpm/test_ddpm.py b/tests/pipelines/ddpm/test_ddpm.py
index 14bc09469773..f4517907cde3 100644
--- a/tests/pipelines/ddpm/test_ddpm.py
+++ b/tests/pipelines/ddpm/test_ddpm.py
@@ -88,7 +88,11 @@ def test_inference_predict_epsilon(self):
             generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
 
-        generator = generator.manual_seed(0)
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
         image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]
 
         image_slice = image[0, -3:, -3:, -1]

From 86d4c5a25428c801fa720f7150a8bafeb6ff98e2 Mon Sep 17 00:00:00 2001
From: anton-l
Date: Wed, 9 Nov 2022 23:47:49 +0100
Subject: [PATCH 4/8] warmup issue again?

---
 .github/workflows/pr_tests.yml                | 2 +-
 src/diffusers/pipelines/ddim/pipeline_ddim.py | 2 +-
 src/diffusers/pipelines/ddpm/pipeline_ddpm.py | 2 +-
 tests/pipelines/ddpm/test_ddpm.py             | 8 +++-----
 4 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml
index c978efe3b7db..dc1c482aa098 100644
--- a/.github/workflows/pr_tests.yml
+++ b/.github/workflows/pr_tests.yml
@@ -136,7 +136,7 @@ jobs:
       - name: Run fast PyTorch tests on M1 (MPS)
         shell: arch -arch arm64 bash {0}
         run: |
-          ${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps tests/
+          ${CONDA_RUN} python -m pytest -n 0 -s -v --make-reports=tests_torch_mps tests/
 
       - name: Failure short reports
         if: ${{ failure() }}
diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py
index d0bca8038ec4..79ab9e2dc871 100644
--- a/src/diffusers/pipelines/ddim/pipeline_ddim.py
+++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py
@@ -78,7 +78,7 @@ def __call__(
         if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
             message = (
                 f"The `generator` device is `{generator.device}` and does not match the pipeline "
-                f"device `{self.device}`, so the `generator` will be set to `None`. "
+                f"device `{self.device}`, so the `generator` will be ignored. "
" f'Please use `generator=torch.Generator(device="{self.device}")` instead.' ) deprecate( diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index d145c5d518a1..04b7e65f4849 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -83,7 +83,7 @@ def __call__( if generator is not None and generator.device.type != self.device.type and self.device.type != "mps": message = ( f"The `generator` device is `{generator.device}` and does not match the pipeline " - f"device `{self.device}`, so the `generator` will be set to `None`. " + f"device `{self.device}`, so the `generator` will be ignored. " f'Please use `torch.Generator(device="{self.device}")` instead.' ) deprecate( diff --git a/tests/pipelines/ddpm/test_ddpm.py b/tests/pipelines/ddpm/test_ddpm.py index f4517907cde3..4d59d08c93aa 100644 --- a/tests/pipelines/ddpm/test_ddpm.py +++ b/tests/pipelines/ddpm/test_ddpm.py @@ -80,7 +80,6 @@ def test_inference_predict_epsilon(self): # Warmup pass when using mps (see #372) if torch_device == "mps": _ = ddpm(num_inference_steps=1) - if torch_device == "mps": # device type MPS is not supported for torch.Generator() api. generator = torch.manual_seed(0) @@ -88,11 +87,10 @@ def test_inference_predict_epsilon(self): generator = torch.Generator(device=torch_device).manual_seed(0) image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images + # Warmup pass when using mps (see #372) if torch_device == "mps": - # device type MPS is not supported for torch.Generator() api. - generator = torch.manual_seed(0) - else: - generator = torch.Generator(device=torch_device).manual_seed(0) + _ = ddpm(num_inference_steps=1) + generator = generator.manual_seed(0) image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0] image_slice = image[0, -3:, -3:, -1] From f300d05cb9782ed320064a0c58577a32d4139e6d Mon Sep 17 00:00:00 2001 From: anton-l Date: Thu, 10 Nov 2022 00:00:07 +0100 Subject: [PATCH 5/8] fix reproducible initial noise --- src/diffusers/pipelines/ddim/pipeline_ddim.py | 12 ++++---- src/diffusers/pipelines/ddpm/pipeline_ddpm.py | 12 ++++---- tests/pipelines/ddpm/test_ddpm.py | 28 ++++++++++--------- 3 files changed, 29 insertions(+), 23 deletions(-) diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py index 79ab9e2dc871..c68e8240899d 100644 --- a/src/diffusers/pipelines/ddim/pipeline_ddim.py +++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py @@ -89,11 +89,13 @@ def __call__( generator = None # Sample gaussian noise to begin loop - image = torch.randn( - (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), - generator=generator, - device=self.device, - ) + image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size) + if self.device.type == "mps": + # randn does not work reproducibly on mps + image = torch.randn(image_shape, generator=generator) + image = image.to(device) + else: + image = torch.randn(image_shape, generator=generator, device=self.device) # set step values self.scheduler.set_timesteps(num_inference_steps) diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index 04b7e65f4849..f28f4406e78e 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -94,11 +94,13 @@ def __call__( generator = 
 
         # Sample gaussian noise to begin loop
-        image = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
-            generator=generator,
-            device=self.device,
-        )
+        image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        if self.device.type == "mps":
+            # randn does not work reproducibly on mps
+            image = torch.randn(image_shape, generator=generator)
+            image = image.to(device)
+        else:
+            image = torch.randn(image_shape, generator=generator, device=self.device)
 
         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
diff --git a/tests/pipelines/ddpm/test_ddpm.py b/tests/pipelines/ddpm/test_ddpm.py
index 4d59d08c93aa..e335c5707861 100644
--- a/tests/pipelines/ddpm/test_ddpm.py
+++ b/tests/pipelines/ddpm/test_ddpm.py
@@ -20,7 +20,7 @@
 
 from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
 from diffusers.utils import deprecate
-from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device
+from diffusers.utils.testing_utils import require_torch, slow, torch_device
 
 from ...test_pipelines_common import PipelineTesterMixin
 
@@ -44,18 +44,21 @@ def dummy_uncond_unet(self):
         return model
 
     def test_inference(self):
-        device = "cpu"
         unet = self.dummy_uncond_unet
         scheduler = DDPMScheduler()
 
         ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
-        ddpm.to(device)
+        ddpm.to(torch_device)
         ddpm.set_progress_bar_config(disable=None)
 
-        generator = torch.Generator(device=device).manual_seed(0)
+        # Warmup pass when using mps (see #372)
+        if torch_device == "mps":
+            _ = ddpm(num_inference_steps=1)
+
+        generator = torch.manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
 
-        generator = torch.Generator(device=device).manual_seed(0)
+        generator = torch.manual_seed(0)
         image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]
 
         image_slice = image[0, -3:, -3:, -1]
@@ -65,8 +68,9 @@ def test_inference(self):
         expected_slice = np.array(
             [5.589e-01, 7.089e-01, 2.632e-01, 6.841e-01, 1.000e-04, 9.999e-01, 1.973e-01, 1.000e-04, 8.010e-02]
         )
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+        tolerance = 1e-2 if torch_device != "mps" else 3e-2
+        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
 
     def test_inference_predict_epsilon(self):
         deprecate("remove this test", "0.10.0", "remove")
@@ -80,6 +84,7 @@ def test_inference_predict_epsilon(self):
         # Warmup pass when using mps (see #372)
         if torch_device == "mps":
             _ = ddpm(num_inference_steps=1)
+
         if torch_device == "mps":
             # device type MPS is not supported for torch.Generator() api.
             generator = torch.manual_seed(0)
@@ -87,9 +92,6 @@ def test_inference_predict_epsilon(self):
             generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
 
-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            _ = ddpm(num_inference_steps=1)
         generator = generator.manual_seed(0)
         image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]
 
@@ -102,7 +104,7 @@ def test_inference_predict_epsilon(self):
 
 
 @slow
-@require_torch_gpu
+@require_torch
 class DDPMPipelineIntegrationTests(unittest.TestCase):
     def test_inference_cifar10(self):
         model_id = "google/ddpm-cifar10-32"
@@ -114,11 +116,11 @@ def test_inference_cifar10(self):
         ddpm.to(torch_device)
         ddpm.set_progress_bar_config(disable=None)
 
-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)
         image = ddpm(generator=generator, output_type="numpy").images
 
         image_slice = image[0, -3:, -3:, -1]
 
         assert image.shape == (1, 32, 32, 3)
-        expected_slice = np.array([0.4454, 0.2025, 0.0315, 0.3023, 0.2575, 0.1031, 0.0953, 0.1604, 0.2020])
+        expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

From b424987099db6e41008b6fa295e93432d9a87953 Mon Sep 17 00:00:00 2001
From: anton-l
Date: Thu, 10 Nov 2022 00:03:17 +0100
Subject: [PATCH 6/8] Revert "fix reproducible initial noise"

This reverts commit f300d05cb9782ed320064a0c58577a32d4139e6d.
---
 src/diffusers/pipelines/ddim/pipeline_ddim.py | 12 ++++----
 src/diffusers/pipelines/ddpm/pipeline_ddpm.py | 12 ++++----
 tests/pipelines/ddpm/test_ddpm.py             | 28 +++++++++----------
 3 files changed, 23 insertions(+), 29 deletions(-)

diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py
index c68e8240899d..79ab9e2dc871 100644
--- a/src/diffusers/pipelines/ddim/pipeline_ddim.py
+++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py
@@ -89,13 +89,11 @@ def __call__(
             generator = None
 
         # Sample gaussian noise to begin loop
-        image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
-        if self.device.type == "mps":
-            # randn does not work reproducibly on mps
-            image = torch.randn(image_shape, generator=generator)
-            image = image.to(device)
-        else:
-            image = torch.randn(image_shape, generator=generator, device=self.device)
+        image = torch.randn(
+            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+            generator=generator,
+            device=self.device,
+        )
 
         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
index f28f4406e78e..04b7e65f4849 100644
--- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
+++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -94,13 +94,11 @@ def __call__(
             generator = None
 
         # Sample gaussian noise to begin loop
-        image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
-        if self.device.type == "mps":
-            # randn does not work reproducibly on mps
-            image = torch.randn(image_shape, generator=generator)
-            image = image.to(device)
-        else:
-            image = torch.randn(image_shape, generator=generator, device=self.device)
+        image = torch.randn(
+            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+            generator=generator,
+            device=self.device,
+        )
 
         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
diff --git a/tests/pipelines/ddpm/test_ddpm.py b/tests/pipelines/ddpm/test_ddpm.py
index e335c5707861..4d59d08c93aa 100644
--- a/tests/pipelines/ddpm/test_ddpm.py
+++ b/tests/pipelines/ddpm/test_ddpm.py
@@ -20,7 +20,7 @@
 
 from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
 from diffusers.utils import deprecate
-from diffusers.utils.testing_utils import require_torch, slow, torch_device
+from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device
 
 from ...test_pipelines_common import PipelineTesterMixin
 
@@ -44,21 +44,18 @@ def dummy_uncond_unet(self):
         return model
 
     def test_inference(self):
+        device = "cpu"
         unet = self.dummy_uncond_unet
         scheduler = DDPMScheduler()
 
         ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
-        ddpm.to(torch_device)
+        ddpm.to(device)
         ddpm.set_progress_bar_config(disable=None)
 
-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            _ = ddpm(num_inference_steps=1)
-
-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=device).manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
 
-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=device).manual_seed(0)
         image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]
 
         image_slice = image[0, -3:, -3:, -1]
@@ -68,9 +65,8 @@ def test_inference(self):
         expected_slice = np.array(
             [5.589e-01, 7.089e-01, 2.632e-01, 6.841e-01, 1.000e-04, 9.999e-01, 1.973e-01, 1.000e-04, 8.010e-02]
         )
-        tolerance = 1e-2 if torch_device != "mps" else 3e-2
-        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
-        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
 
     def test_inference_predict_epsilon(self):
         deprecate("remove this test", "0.10.0", "remove")
@@ -84,7 +80,6 @@ def test_inference_predict_epsilon(self):
         # Warmup pass when using mps (see #372)
         if torch_device == "mps":
             _ = ddpm(num_inference_steps=1)
-
         if torch_device == "mps":
             # device type MPS is not supported for torch.Generator() api.
             generator = torch.manual_seed(0)
@@ -92,6 +87,9 @@ def test_inference_predict_epsilon(self):
             generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
 
+        # Warmup pass when using mps (see #372)
+        if torch_device == "mps":
+            _ = ddpm(num_inference_steps=1)
         generator = generator.manual_seed(0)
         image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]
 
@@ -104,7 +102,7 @@ def test_inference_predict_epsilon(self):
 
 
 @slow
-@require_torch
+@require_torch_gpu
 class DDPMPipelineIntegrationTests(unittest.TestCase):
     def test_inference_cifar10(self):
         model_id = "google/ddpm-cifar10-32"
@@ -116,11 +114,11 @@ def test_inference_cifar10(self):
         ddpm.to(torch_device)
         ddpm.set_progress_bar_config(disable=None)
 
-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, output_type="numpy").images
 
         image_slice = image[0, -3:, -3:, -1]
 
         assert image.shape == (1, 32, 32, 3)
-        expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845])
+        expected_slice = np.array([0.4454, 0.2025, 0.0315, 0.3023, 0.2575, 0.1031, 0.0953, 0.1604, 0.2020])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

From 59a3dce1bd8c93714e7be79544099d420d881fe7 Mon Sep 17 00:00:00 2001
From: anton-l
Date: Thu, 10 Nov 2022 00:04:42 +0100
Subject: [PATCH 7/8] fix reproducible initial noise

---
 src/diffusers/pipelines/ddim/pipeline_ddim.py | 12 +++++++-----
 src/diffusers/pipelines/ddpm/pipeline_ddpm.py | 12 +++++++-----
 tests/pipelines/ddpm/test_ddpm.py             |  4 +---
 3 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py
index 79ab9e2dc871..c68e8240899d 100644
--- a/src/diffusers/pipelines/ddim/pipeline_ddim.py
+++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py
@@ -89,11 +89,13 @@ def __call__(
             generator = None
 
         # Sample gaussian noise to begin loop
-        image = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
-            generator=generator,
-            device=self.device,
-        )
+        image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        if self.device.type == "mps":
+            # randn does not work reproducibly on mps
+            image = torch.randn(image_shape, generator=generator)
+            image = image.to(device)
+        else:
+            image = torch.randn(image_shape, generator=generator, device=self.device)
 
         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
index 04b7e65f4849..f28f4406e78e 100644
--- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
+++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -94,11 +94,13 @@ def __call__(
             generator = None
 
         # Sample gaussian noise to begin loop
-        image = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
-            generator=generator,
-            device=self.device,
-        )
+        image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        if self.device.type == "mps":
+            # randn does not work reproducibly on mps
+            image = torch.randn(image_shape, generator=generator)
+            image = image.to(device)
+        else:
+            image = torch.randn(image_shape, generator=generator, device=self.device)
 
         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
diff --git a/tests/pipelines/ddpm/test_ddpm.py b/tests/pipelines/ddpm/test_ddpm.py
index 4d59d08c93aa..14bc09469773 100644
--- a/tests/pipelines/ddpm/test_ddpm.py
+++ b/tests/pipelines/ddpm/test_ddpm.py
@@ -80,6 +80,7 @@ def test_inference_predict_epsilon(self):
         # Warmup pass when using mps (see #372)
         if torch_device == "mps":
             _ = ddpm(num_inference_steps=1)
+
         if torch_device == "mps":
             # device type MPS is not supported for torch.Generator() api.
             generator = torch.manual_seed(0)
@@ -87,9 +88,6 @@ def test_inference_predict_epsilon(self):
             generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
 
-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            _ = ddpm(num_inference_steps=1)
         generator = generator.manual_seed(0)
         image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]
 

From 963bde77fec045f55613aabf7d3e81bdda21bbd2 Mon Sep 17 00:00:00 2001
From: anton-l
Date: Thu, 10 Nov 2022 00:06:51 +0100
Subject: [PATCH 8/8] fix device

---
 src/diffusers/pipelines/ddim/pipeline_ddim.py | 2 +-
 src/diffusers/pipelines/ddpm/pipeline_ddpm.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py
index c68e8240899d..6db6298329a7 100644
--- a/src/diffusers/pipelines/ddim/pipeline_ddim.py
+++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py
@@ -93,7 +93,7 @@ def __call__(
         if self.device.type == "mps":
             # randn does not work reproducibly on mps
             image = torch.randn(image_shape, generator=generator)
-            image = image.to(device)
+            image = image.to(self.device)
         else:
             image = torch.randn(image_shape, generator=generator, device=self.device)
 
diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
index f28f4406e78e..b7194664f4c4 100644
--- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
+++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -98,7 +98,7 @@ def __call__(
         if self.device.type == "mps":
             # randn does not work reproducibly on mps
             image = torch.randn(image_shape, generator=generator)
-            image = image.to(device)
+            image = image.to(self.device)
         else:
             image = torch.randn(image_shape, generator=generator, device=self.device)