diff --git a/tests/pipelines/dit/test_dit.py b/tests/pipelines/dit/test_dit.py index 937265ab05e0..ade86e183dae 100644 --- a/tests/pipelines/dit/test_dit.py +++ b/tests/pipelines/dit/test_dit.py @@ -47,10 +47,10 @@ class DiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase): def get_dummy_components(self): torch.manual_seed(0) transformer = Transformer2DModel( - sample_size=16, - num_layers=2, - patch_size=4, - attention_head_dim=8, + sample_size=2, + num_layers=1, + patch_size=1, + attention_head_dim=2, num_attention_heads=2, in_channels=4, out_channels=8, @@ -90,8 +90,8 @@ def test_inference(self): image = pipe(**inputs).images image_slice = image[0, -3:, -3:, -1] - self.assertEqual(image.shape, (1, 16, 16, 3)) - expected_slice = np.array([0.2946, 0.6601, 0.4329, 0.3296, 0.4144, 0.5319, 0.7273, 0.5013, 0.4457]) + self.assertEqual(image.shape, (1, 2, 2, 3)) + expected_slice = np.array([0.485, 0.6022, 0.5567, 0.6807]) max_diff = np.abs(image_slice.flatten() - expected_slice).max() self.assertLessEqual(max_diff, 1e-3)