@@ -336,23 +336,28 @@ def test_pixart_1024(self):
336336 image = pipe (prompt , generator = generator , num_inference_steps = 2 , output_type = "np" ).images
337337
338338 image_slice = image [0 , - 3 :, - 3 :, - 1 ]
339- expected_slice = np .array ([0.0742 , 0.0835 , 0.2114 , 0.0295 , 0.0784 , 0.2361 , 0.1738 , 0.2251 , 0.3589 ])
339+ expected_slice = np .array ([0.4517 , 0.4446 , 0.4375 , 0.449 , 0.4399 , 0.4365 , 0.4583 , 0.4629 , 0.4473 ])
340340
341341 max_diff = numpy_cosine_similarity_distance (image_slice .flatten (), expected_slice )
342342 self .assertLessEqual (max_diff , 1e-4 )
343343
344344 def test_pixart_512 (self ):
345345 generator = torch .Generator ("cpu" ).manual_seed (0 )
346346
347- pipe = PixArtSigmaPipeline .from_pretrained (self .ckpt_id_512 , torch_dtype = torch .float16 )
347+ transformer = Transformer2DModel .from_pretrained (
348+ self .ckpt_id_512 , subfolder = "transformer" , torch_dtype = torch .float16
349+ )
350+ pipe = PixArtSigmaPipeline .from_pretrained (
351+ self .ckpt_id_1024 , transformer = transformer , torch_dtype = torch .float16
352+ )
348353 pipe .enable_model_cpu_offload ()
349354
350355 prompt = self .prompt
351356
352357 image = pipe (prompt , generator = generator , num_inference_steps = 2 , output_type = "np" ).images
353358
354359 image_slice = image [0 , - 3 :, - 3 :, - 1 ]
355- expected_slice = np .array ([0.3477 , 0.3882 , 0.4541 , 0.3413 , 0.3821 , 0.4463 , 0.4001 , 0.4409 , 0.4958 ])
360+ expected_slice = np .array ([0.0479 , 0.0378 , 0.0217 , 0.0942 , 0.064 , 0.0791 , 0.2073 , 0.1975 , 0.2017 ])
356361
357362 max_diff = numpy_cosine_similarity_distance (image_slice .flatten (), expected_slice )
358363 self .assertLessEqual (max_diff , 1e-4 )
@@ -394,7 +399,12 @@ def test_pixart_1024_without_resolution_binning(self):
394399 def test_pixart_512_without_resolution_binning (self ):
395400 generator = torch .manual_seed (0 )
396401
397- pipe = PixArtSigmaPipeline .from_pretrained (self .ckpt_id_512 , torch_dtype = torch .float16 )
402+ transformer = Transformer2DModel .from_pretrained (
403+ self .ckpt_id_512 , subfolder = "transformer" , torch_dtype = torch .float16
404+ )
405+ pipe = PixArtSigmaPipeline .from_pretrained (
406+ self .ckpt_id_1024 , transformer = transformer , torch_dtype = torch .float16
407+ )
398408 pipe .enable_model_cpu_offload ()
399409
400410 prompt = self .prompt
0 commit comments