@@ -78,11 +78,10 @@ def test_dummy_all_tpus(self):
78
78
79
79
assert images .shape == (num_samples , 1 , 64 , 64 , 3 )
80
80
if jax .device_count () == 8 :
81
- assert np .abs (np .abs (images [0 , 0 , :2 , :2 , - 2 :], dtype = np .float32 ).sum () - 3.1111548 ) < 1e-3
82
- assert np .abs (np .abs (images , dtype = np .float32 ).sum () - 199746.95 ) < 5e-1
81
+ assert np .abs (np .abs (images [0 , 0 , :2 , :2 , - 2 :], dtype = np .float32 ).sum () - 4.1514745 ) < 1e-3
82
+ assert np .abs (np .abs (images , dtype = np .float32 ).sum () - 49947.875 ) < 5e-1
83
83
84
84
images_pil = pipeline .numpy_to_pil (np .asarray (images .reshape ((num_samples ,) + images .shape [- 3 :])))
85
-
86
85
assert len (images_pil ) == num_samples
87
86
88
87
def test_stable_diffusion_v1_4 (self ):
@@ -140,8 +139,8 @@ def test_stable_diffusion_v1_4_bfloat_16(self):
140
139
141
140
assert images .shape == (num_samples , 1 , 512 , 512 , 3 )
142
141
if jax .device_count () == 8 :
143
- assert np .abs ((np .abs (images [0 , 0 , :2 , :2 , - 2 :], dtype = np .float32 ).sum () - 0.06652832 )) < 1e-3
144
- assert np .abs ((np .abs (images , dtype = np .float32 ).sum () - 2384849.8 )) < 5e-1
142
+ assert np .abs ((np .abs (images [0 , 0 , :2 , :2 , - 2 :], dtype = np .float32 ).sum () - 0.04003906 )) < 1e-3
143
+ assert np .abs ((np .abs (images , dtype = np .float32 ).sum () - 2373516.75 )) < 5e-1
145
144
146
145
def test_stable_diffusion_v1_4_bfloat_16_with_safety (self ):
147
146
pipeline , params = FlaxStableDiffusionPipeline .from_pretrained (
@@ -169,8 +168,8 @@ def test_stable_diffusion_v1_4_bfloat_16_with_safety(self):
169
168
170
169
assert images .shape == (num_samples , 1 , 512 , 512 , 3 )
171
170
if jax .device_count () == 8 :
172
- assert np .abs ((np .abs (images [0 , 0 , :2 , :2 , - 2 :], dtype = np .float32 ).sum () - 0.06652832 )) < 1e-3
173
- assert np .abs ((np .abs (images , dtype = np .float32 ).sum () - 2384849.8 )) < 5e-1
171
+ assert np .abs ((np .abs (images [0 , 0 , :2 , :2 , - 2 :], dtype = np .float32 ).sum () - 0.04003906 )) < 1e-3
172
+ assert np .abs ((np .abs (images , dtype = np .float32 ).sum () - 2373516.75 )) < 5e-1
174
173
175
174
def test_stable_diffusion_v1_4_bfloat_16_ddim (self ):
176
175
scheduler = FlaxDDIMScheduler (
0 commit comments