Skip to content

Commit f6feb69

Browse files
authored
Relax DiT test (#2808)
* Relax DiT test * relax 2 more tests * fix style * skip test on mac due to older protobuf
1 parent 37a44bb commit f6feb69

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

tests/pipelines/dit/test_dit.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import torch
2121

2222
from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, DPMSolverMultistepScheduler, Transformer2DModel
23-
from diffusers.utils import load_numpy, slow
23+
from diffusers.utils import is_xformers_available, load_numpy, slow, torch_device
2424
from diffusers.utils.testing_utils import require_torch_gpu
2525

2626
from ...pipeline_params import (
@@ -97,7 +97,14 @@ def test_inference(self):
9797
self.assertLessEqual(max_diff, 1e-3)
9898

9999
def test_inference_batch_single_identical(self):
100-
self._test_inference_batch_single_identical(relax_max_difference=True)
100+
self._test_inference_batch_single_identical(relax_max_difference=True, expected_max_diff=1e-3)
101+
102+
@unittest.skipIf(
103+
torch_device != "cuda" or not is_xformers_available(),
104+
reason="XFormers attention is only available with CUDA and `xformers` installed",
105+
)
106+
def test_xformers_attention_forwardGenerator_pass(self):
107+
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
101108

102109

103110
@require_torch_gpu
@@ -123,7 +130,7 @@ def test_dit_256(self):
123130
expected_image = load_numpy(
124131
f"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/dit/{word}.npy"
125132
)
126-
assert np.abs((expected_image - image).max()) < 1e-3
133+
assert np.abs((expected_image - image).max()) < 1e-2
127134

128135
def test_dit_512(self):
129136
pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-512")

tests/pipelines/spectrogram_diffusion/test_spectrogram_diffusion.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,10 @@ def test_inference_batch_single_identical(self):
153153
def test_inference_batch_consistent(self):
154154
pass
155155

156+
@skip_mps
157+
def test_progress_bar(self):
158+
return super().test_progress_bar()
159+
156160

157161
@slow
158162
@require_torch_gpu

0 commit comments

Comments
 (0)