diff --git a/tests/pipelines/ddpm/test_ddpm.py b/tests/pipelines/ddpm/test_ddpm.py
index a16b3782a421..c3ea0045b4c1 100644
--- a/tests/pipelines/ddpm/test_ddpm.py
+++ b/tests/pipelines/ddpm/test_ddpm.py
@@ -66,32 +66,6 @@ def test_fast_inference(self):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
 
-    def test_full_inference(self):
-        device = "cpu"
-        unet = self.dummy_uncond_unet
-        scheduler = DDPMScheduler()
-
-        ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
-        ddpm.to(device)
-        ddpm.set_progress_bar_config(disable=None)
-
-        generator = torch.Generator(device=device).manual_seed(0)
-        image = ddpm(generator=generator, output_type="numpy").images
-
-        generator = torch.Generator(device=device).manual_seed(0)
-        image_from_tuple = ddpm(generator=generator, output_type="numpy", return_dict=False)[0]
-
-        image_slice = image[0, -3:, -3:, -1]
-        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
-        assert image.shape == (1, 32, 32, 3)
-        expected_slice = np.array(
-            [1.0, 3.495e-02, 2.939e-01, 9.821e-01, 9.448e-01, 6.261e-03, 7.998e-01, 8.9e-01, 1.122e-02]
-        )
-
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
     def test_inference_predict_sample(self):
         unet = self.dummy_uncond_unet
         scheduler = DDPMScheduler(prediction_type="sample")