DDIM scheduler: sample the variance noise on a device the generator supports.

src/diffusers/schedulers/scheduling_ddim.py
    `step` samples `variance_noise` directly on the target device when no
    generator is supplied or when the generator's device type matches the
    target device type. Otherwise it samples without a `device` argument and
    moves the result with `.to(device)`. The existing mps branch is unchanged.

tests/pipelines/ddim/test_ddim.py
    Adds `test_inference_eta`, which runs the pipeline with eta=1.0 for two
    inference steps and compares the `[0, -3:, -3:, -1]` slice of the output
    against stored reference values.

tests/test_pipelines.py
    Removes the `parameterized` import and the unbatched
    `test_ddpm_ddim_equality`; `test_ddpm_ddim_equality_batched` now uses a
    fixed seed of 0 and a batch size of 2.

diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py
index 75cef635d063..10494ca9ab26 100644
--- a/src/diffusers/schedulers/scheduling_ddim.py
+++ b/src/diffusers/schedulers/scheduling_ddim.py
@@ -300,10 +300,14 @@ def step(
                 # randn does not work reproducibly on mps
                 variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
                 variance_noise = variance_noise.to(device)
-            else:
+            elif generator is None or generator.device.type == device.type:
                 variance_noise = torch.randn(
                     model_output.shape, generator=generator, device=device, dtype=model_output.dtype
                 )
+            else:
+                variance_noise = torch.randn(model_output.shape, generator=generator, dtype=model_output.dtype)
+                variance_noise = variance_noise.to(device)
+
             variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise
 
             prev_sample = prev_sample + variance
diff --git a/tests/pipelines/ddim/test_ddim.py b/tests/pipelines/ddim/test_ddim.py
index 4445fe7feecf..288ae844e17c 100644
--- a/tests/pipelines/ddim/test_ddim.py
+++ b/tests/pipelines/ddim/test_ddim.py
@@ -71,6 +71,30 @@ def test_inference(self):
         assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
 
+    def test_inference_eta(self):
+        unet = self.dummy_uncond_unet
+        scheduler = DDIMScheduler()
+
+        ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
+        ddim.to(torch_device)
+        ddim.set_progress_bar_config(disable=None)
+
+        # Warmup pass when using mps (see #372)
+        if torch_device == "mps":
+            _ = ddim(num_inference_steps=1)
+
+        generator = torch.manual_seed(0)
+        # set eta > 0 to test the variance noise generation
+        image = ddim(generator=generator, eta=1.0, num_inference_steps=2, output_type="numpy").images
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array(
+            [[9.955e-01, 5.785e-01, 4.674e-01, 9.93e-01, 0.0e00, 1.0e00, 1.2e-03, 2.73e-04, 5.19e-04]]
+        )
+        tolerance = 1e-2 if torch_device != "mps" else 3e-2
+        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
+
 
 @slow
 @require_torch
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index 753c821dd315..34f5acccc316 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -42,7 +42,6 @@
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device
 from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
-from parameterized import parameterized
 from PIL import Image
 from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
@@ -462,40 +461,9 @@ def test_output_format(self):
         assert isinstance(images, list)
         assert isinstance(images[0], PIL.Image.Image)
 
-    # Make sure the test passes for different values of random seed
-    @parameterized.expand([(0,), (4,)])
-    def test_ddpm_ddim_equality(self, seed):
-        model_id = "google/ddpm-cifar10-32"
-
-        unet = UNet2DModel.from_pretrained(model_id)
-        ddpm_scheduler = DDPMScheduler()
-        ddim_scheduler = DDIMScheduler()
-
-        ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
-        ddpm.to(torch_device)
-        ddpm.set_progress_bar_config(disable=None)
-        ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
-        ddim.to(torch_device)
-        ddim.set_progress_bar_config(disable=None)
-
-        generator = torch.manual_seed(seed)
-        ddpm_image = ddpm(generator=generator, output_type="numpy").images
-
-        generator = torch.manual_seed(seed)
-        ddim_image = ddim(
-            generator=generator,
-            num_inference_steps=1000,
-            eta=1.0,
-            output_type="numpy",
-            use_clipped_model_output=True,  # Need this to make DDIM match DDPM
-        ).images
-
-        # the values aren't exactly equal, but the images look the same visually
-        assert np.abs(ddpm_image - ddim_image).max() < 1e-1
-
-    # Make sure the test passes for different values of random seed
-    @parameterized.expand([(0,), (4,)])
-    def test_ddpm_ddim_equality_batched(self, seed):
+    def test_ddpm_ddim_equality_batched(self):
+        seed = 0
+        batch_size = 2
         model_id = "google/ddpm-cifar10-32"
 
         unet = UNet2DModel.from_pretrained(model_id)
@@ -511,11 +479,11 @@ def test_ddpm_ddim_equality_batched(self, seed):
         ddim.set_progress_bar_config(disable=None)
 
         generator = torch.manual_seed(seed)
-        ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy").images
+        ddpm_images = ddpm(batch_size=batch_size, generator=generator, output_type="numpy").images
 
         generator = torch.manual_seed(seed)
         ddim_images = ddim(
-            batch_size=4,
+            batch_size=batch_size,
             generator=generator,
             num_inference_steps=1000,
             eta=1.0,
@@ -523,5 +491,4 @@ def test_ddpm_ddim_equality_batched(self, seed):
             use_clipped_model_output=True,  # Need this to make DDIM match DDPM
         ).images
 
-        # the values aren't exactly equal, but the images look the same visually
         assert np.abs(ddpm_images - ddim_images).max() < 1e-1