Skip to content

Commit cbb18d8

Browse files
pcuencapatrickvonplaten
authored andcommitted
Update Flax TPU tests (huggingface#3069)
Update Flax TPU tests. Co-authored-by: Patrick von Platen <[email protected]>
1 parent 870169e commit cbb18d8

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

tests/test_pipelines_flax.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,10 @@ def test_dummy_all_tpus(self):
7878

7979
assert images.shape == (num_samples, 1, 64, 64, 3)
8080
if jax.device_count() == 8:
81-
assert np.abs(np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 3.1111548) < 1e-3
82-
assert np.abs(np.abs(images, dtype=np.float32).sum() - 199746.95) < 5e-1
81+
assert np.abs(np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 4.1514745) < 1e-3
82+
assert np.abs(np.abs(images, dtype=np.float32).sum() - 49947.875) < 5e-1
8383

8484
images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
85-
8685
assert len(images_pil) == num_samples
8786

8887
def test_stable_diffusion_v1_4(self):
@@ -140,8 +139,8 @@ def test_stable_diffusion_v1_4_bfloat_16(self):
140139

141140
assert images.shape == (num_samples, 1, 512, 512, 3)
142141
if jax.device_count() == 8:
143-
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3
144-
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 5e-1
142+
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.04003906)) < 1e-3
143+
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2373516.75)) < 5e-1
145144

146145
def test_stable_diffusion_v1_4_bfloat_16_with_safety(self):
147146
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
@@ -169,8 +168,8 @@ def test_stable_diffusion_v1_4_bfloat_16_with_safety(self):
169168

170169
assert images.shape == (num_samples, 1, 512, 512, 3)
171170
if jax.device_count() == 8:
172-
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3
173-
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 5e-1
171+
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.04003906)) < 1e-3
172+
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2373516.75)) < 5e-1
174173

175174
def test_stable_diffusion_v1_4_bfloat_16_ddim(self):
176175
scheduler = FlaxDDIMScheduler(

0 commit comments

Comments
 (0)