Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/diffusers/schedulers/scheduling_ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,10 +300,14 @@ def step(
# randn does not work reproducibly on mps
variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
variance_noise = variance_noise.to(device)
else:
elif generator is not None and generator.device.type == device.type:
variance_noise = torch.randn(
model_output.shape, generator=generator, device=device, dtype=model_output.dtype
)
else:
variance_noise = torch.randn(model_output.shape, generator=generator, dtype=model_output.dtype)
variance_noise = variance_noise.to(device)

variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise

prev_sample = prev_sample + variance
Expand Down
24 changes: 24 additions & 0 deletions tests/pipelines/ddim/test_ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,30 @@ def test_inference(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance

def test_inference_eta(self):
unet = self.dummy_uncond_unet
scheduler = DDIMScheduler()

ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
ddim.to(torch_device)
ddim.set_progress_bar_config(disable=None)

# Warmup pass when using mps (see #372)
if torch_device == "mps":
_ = ddim(num_inference_steps=1)

generator = torch.manual_seed(0)
# set eta > 0 to test the variance noise generation
image = ddim(generator=generator, eta=1.0, num_inference_steps=2, output_type="numpy").images
image_slice = image[0, -3:, -3:, -1]
Comment on lines +86 to +89
Copy link
Member Author

@anton-l anton-l Nov 9, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This tests fails with a device mismatch without the proposed change


assert image.shape == (1, 32, 32, 3)
expected_slice = np.array(
[[9.955e-01, 5.785e-01, 4.674e-01, 9.93e-01, 0.0e00, 1.0e00, 1.2e-03, 2.73e-04, 5.19e-04]]
)
tolerance = 1e-2 if torch_device != "mps" else 3e-2
assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance


@slow
@require_torch
Expand Down
43 changes: 5 additions & 38 deletions tests/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device
from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
from parameterized import parameterized
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer

Expand Down Expand Up @@ -462,40 +461,9 @@ def test_output_format(self):
assert isinstance(images, list)
assert isinstance(images[0], PIL.Image.Image)

# Make sure the test passes for different values of random seed
@parameterized.expand([(0,), (4,)])
def test_ddpm_ddim_equality(self, seed):
model_id = "google/ddpm-cifar10-32"

unet = UNet2DModel.from_pretrained(model_id)
ddpm_scheduler = DDPMScheduler()
ddim_scheduler = DDIMScheduler()

ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None)
ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
ddim.to(torch_device)
ddim.set_progress_bar_config(disable=None)

generator = torch.manual_seed(seed)
ddpm_image = ddpm(generator=generator, output_type="numpy").images

generator = torch.manual_seed(seed)
ddim_image = ddim(
generator=generator,
num_inference_steps=1000,
eta=1.0,
output_type="numpy",
use_clipped_model_output=True, # Need this to make DDIM match DDPM
).images

# the values aren't exactly equal, but the images look the same visually
assert np.abs(ddpm_image - ddim_image).max() < 1e-1

# Make sure the test passes for different values of random seed
@parameterized.expand([(0,), (4,)])
def test_ddpm_ddim_equality_batched(self, seed):
def test_ddpm_ddim_equality_batched(self):
seed = 0
batch_size = 2
model_id = "google/ddpm-cifar10-32"

unet = UNet2DModel.from_pretrained(model_id)
Expand All @@ -511,17 +479,16 @@ def test_ddpm_ddim_equality_batched(self, seed):
ddim.set_progress_bar_config(disable=None)

generator = torch.manual_seed(seed)
ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy").images
ddpm_images = ddpm(batch_size=batch_size, generator=generator, output_type="numpy").images

generator = torch.manual_seed(seed)
ddim_images = ddim(
batch_size=4,
batch_size=batch_size,
generator=generator,
num_inference_steps=1000,
eta=1.0,
output_type="numpy",
use_clipped_model_output=True, # Need this to make DDIM match DDPM
).images

# the values aren't exactly equal, but the images look the same visually
assert np.abs(ddpm_images - ddim_images).max() < 1e-1