Skip to content

Commit 9221da4

Browse files
authored
fix: init for vae during pixart tests (#6215)
* fix: init for vae during pixart tests * print the values * add flatten * correct assertion value for test_inference * correct assertion values for test_inference_non_square_images * run styling * debug test_inference_with_multiple_images_per_prompt * fix assertion values for test_inference_with_multiple_images_per_prompt
1 parent 57fde87 commit 9221da4

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

tests/pipelines/pixart/test_pixart.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@ def get_dummy_components(self):
6464
norm_elementwise_affine=False,
6565
norm_eps=1e-6,
6666
)
67+
torch.manual_seed(0)
6768
vae = AutoencoderKL()
69+
6870
scheduler = DDIMScheduler()
6971
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
7072

@@ -186,7 +188,7 @@ def test_inference(self):
186188
image_slice = image[0, -3:, -3:, -1]
187189

188190
self.assertEqual(image.shape, (1, 8, 8, 3))
189-
expected_slice = np.array([0.5303, 0.2658, 0.7979, 0.1182, 0.3304, 0.4608, 0.5195, 0.4261, 0.4675])
191+
expected_slice = np.array([0.6319, 0.3526, 0.3806, 0.6327, 0.4639, 0.483, 0.2583, 0.5331, 0.4852])
190192
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
191193
self.assertLessEqual(max_diff, 1e-3)
192194

@@ -203,7 +205,7 @@ def test_inference_non_square_images(self):
203205
image_slice = image[0, -3:, -3:, -1]
204206
self.assertEqual(image.shape, (1, 32, 48, 3))
205207

206-
expected_slice = np.array([0.3859, 0.2987, 0.2333, 0.5243, 0.6721, 0.4436, 0.5292, 0.5373, 0.4416])
208+
expected_slice = np.array([0.6493, 0.537, 0.4081, 0.4762, 0.3695, 0.4711, 0.3026, 0.5218, 0.5263])
207209
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
208210
self.assertLessEqual(max_diff, 1e-3)
209211

@@ -293,7 +295,7 @@ def test_inference_with_multiple_images_per_prompt(self):
293295
image_slice = image[0, -3:, -3:, -1]
294296

295297
self.assertEqual(image.shape, (2, 8, 8, 3))
296-
expected_slice = np.array([0.5303, 0.2658, 0.7979, 0.1182, 0.3304, 0.4608, 0.5195, 0.4261, 0.4675])
298+
expected_slice = np.array([0.6319, 0.3526, 0.3806, 0.6327, 0.4639, 0.483, 0.2583, 0.5331, 0.4852])
297299
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
298300
self.assertLessEqual(max_diff, 1e-3)
299301

0 commit comments

Comments
 (0)