Skip to content

Commit cfd25b6

Browse files
committed
[tests] fix Pixart Sigma tests (#7966)
* checking tests * checking ii. * remove prints. * test_pixart_1024 * fix 1024.
1 parent e9eaea7 commit cfd25b6

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

tests/pipelines/pixart_sigma/test_pixart.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -336,23 +336,28 @@ def test_pixart_1024(self):
336336
image = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").images
337337

338338
image_slice = image[0, -3:, -3:, -1]
339-
expected_slice = np.array([0.0742, 0.0835, 0.2114, 0.0295, 0.0784, 0.2361, 0.1738, 0.2251, 0.3589])
339+
expected_slice = np.array([0.4517, 0.4446, 0.4375, 0.449, 0.4399, 0.4365, 0.4583, 0.4629, 0.4473])
340340

341341
max_diff = numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice)
342342
self.assertLessEqual(max_diff, 1e-4)
343343

344344
def test_pixart_512(self):
345345
generator = torch.Generator("cpu").manual_seed(0)
346346

347-
pipe = PixArtSigmaPipeline.from_pretrained(self.ckpt_id_512, torch_dtype=torch.float16)
347+
transformer = Transformer2DModel.from_pretrained(
348+
self.ckpt_id_512, subfolder="transformer", torch_dtype=torch.float16
349+
)
350+
pipe = PixArtSigmaPipeline.from_pretrained(
351+
self.ckpt_id_1024, transformer=transformer, torch_dtype=torch.float16
352+
)
348353
pipe.enable_model_cpu_offload()
349354

350355
prompt = self.prompt
351356

352357
image = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").images
353358

354359
image_slice = image[0, -3:, -3:, -1]
355-
expected_slice = np.array([0.3477, 0.3882, 0.4541, 0.3413, 0.3821, 0.4463, 0.4001, 0.4409, 0.4958])
360+
expected_slice = np.array([0.0479, 0.0378, 0.0217, 0.0942, 0.064, 0.0791, 0.2073, 0.1975, 0.2017])
356361

357362
max_diff = numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice)
358363
self.assertLessEqual(max_diff, 1e-4)
@@ -394,7 +399,12 @@ def test_pixart_1024_without_resolution_binning(self):
394399
def test_pixart_512_without_resolution_binning(self):
395400
generator = torch.manual_seed(0)
396401

397-
pipe = PixArtSigmaPipeline.from_pretrained(self.ckpt_id_512, torch_dtype=torch.float16)
402+
transformer = Transformer2DModel.from_pretrained(
403+
self.ckpt_id_512, subfolder="transformer", torch_dtype=torch.float16
404+
)
405+
pipe = PixArtSigmaPipeline.from_pretrained(
406+
self.ckpt_id_1024, transformer=transformer, torch_dtype=torch.float16
407+
)
398408
pipe.enable_model_cpu_offload()
399409

400410
prompt = self.prompt

0 commit comments

Comments
 (0)