diff --git a/tests/pipelines/ddim/test_ddim.py b/tests/pipelines/ddim/test_ddim.py index 0f0654397a34..2078a592ceca 100644 --- a/tests/pipelines/ddim/test_ddim.py +++ b/tests/pipelines/ddim/test_ddim.py @@ -42,9 +42,10 @@ class DDIMPipelineFastTests(PipelineTesterMixin, unittest.TestCase): def get_dummy_components(self): torch.manual_seed(0) unet = UNet2DModel( - block_out_channels=(32, 64), - layers_per_block=2, - sample_size=32, + block_out_channels=(4, 8), + layers_per_block=1, + norm_num_groups=4, + sample_size=8, in_channels=3, out_channels=3, down_block_types=("DownBlock2D", "AttnDownBlock2D"), @@ -79,10 +80,8 @@ def test_inference(self): image = pipe(**inputs).images image_slice = image[0, -3:, -3:, -1] - self.assertEqual(image.shape, (1, 32, 32, 3)) - expected_slice = np.array( - [1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04] - ) + self.assertEqual(image.shape, (1, 8, 8, 3)) + expected_slice = np.array([0.0, 9.979e-01, 0.0, 9.999e-01, 9.986e-01, 9.991e-01, 7.106e-04, 0.0, 0.0]) max_diff = np.abs(image_slice.flatten() - expected_slice).max() self.assertLessEqual(max_diff, 1e-3)