diff --git a/tests/pipelines/ddpm/test_ddpm.py b/tests/pipelines/ddpm/test_ddpm.py index c0cce3a2f237..f6d0821da4c2 100644 --- a/tests/pipelines/ddpm/test_ddpm.py +++ b/tests/pipelines/ddpm/test_ddpm.py @@ -30,9 +30,10 @@ class DDPMPipelineFastTests(unittest.TestCase): def dummy_uncond_unet(self): torch.manual_seed(0) model = UNet2DModel( - block_out_channels=(32, 64), - layers_per_block=2, - sample_size=32, + block_out_channels=(4, 8), + layers_per_block=1, + norm_num_groups=4, + sample_size=8, in_channels=3, out_channels=3, down_block_types=("DownBlock2D", "AttnDownBlock2D"), @@ -58,10 +59,8 @@ def test_fast_inference(self): image_slice = image[0, -3:, -3:, -1] image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - assert image.shape == (1, 32, 32, 3) - expected_slice = np.array( - [9.956e-01, 5.785e-01, 4.675e-01, 9.930e-01, 0.0, 1.000, 1.199e-03, 2.648e-04, 5.101e-04] - ) + assert image.shape == (1, 8, 8, 3) + expected_slice = np.array([0.0, 0.9996672, 0.00329116, 1.0, 0.9995991, 1.0, 0.0060907, 0.00115037, 0.0]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -83,7 +82,7 @@ def test_inference_predict_sample(self): image_slice = image[0, -3:, -3:, -1] image_eps_slice = image_eps[0, -3:, -3:, -1] - assert image.shape == (1, 32, 32, 3) + assert image.shape == (1, 8, 8, 3) tolerance = 1e-2 if torch_device != "mps" else 3e-2 assert np.abs(image_slice.flatten() - image_eps_slice.flatten()).max() < tolerance