@@ -1314,6 +1314,59 @@ def test_stable_diffusion_inpaint_fp16(self):
13141314
13151315 assert image .shape == (1 , 32 , 32 , 3 )
13161316
1317+ def test_components (self ):
1318+ """Test that components property works correctly"""
1319+ unet = self .dummy_cond_unet
1320+ scheduler = PNDMScheduler (skip_prk_steps = True )
1321+ vae = self .dummy_vae
1322+ bert = self .dummy_text_encoder
1323+ tokenizer = CLIPTokenizer .from_pretrained ("hf-internal-testing/tiny-random-clip" )
1324+
1325+ image = self .dummy_image .cpu ().permute (0 , 2 , 3 , 1 )[0 ]
1326+ init_image = Image .fromarray (np .uint8 (image )).convert ("RGB" )
1327+ mask_image = Image .fromarray (np .uint8 (image + 4 )).convert ("RGB" ).resize ((128 , 128 ))
1328+
1329+ # make sure here that pndm scheduler skips prk
1330+ inpaint = StableDiffusionInpaintPipeline (
1331+ unet = unet ,
1332+ scheduler = scheduler ,
1333+ vae = vae ,
1334+ text_encoder = bert ,
1335+ tokenizer = tokenizer ,
1336+ safety_checker = self .dummy_safety_checker ,
1337+ feature_extractor = self .dummy_extractor ,
1338+ )
1339+ img2img = StableDiffusionImg2ImgPipeline (** inpaint .components )
1340+ text2img = StableDiffusionPipeline (** inpaint .components )
1341+
1342+ prompt = "A painting of a squirrel eating a burger"
1343+ generator = torch .Generator (device = torch_device ).manual_seed (0 )
1344+ image_inpaint = inpaint (
1345+ [prompt ],
1346+ generator = generator ,
1347+ num_inference_steps = 2 ,
1348+ output_type = "np" ,
1349+ init_image = init_image ,
1350+ mask_image = mask_image ,
1351+ ).images
1352+ image_img2img = img2img (
1353+ [prompt ],
1354+ generator = generator ,
1355+ num_inference_steps = 2 ,
1356+ output_type = "np" ,
1357+ init_image = init_image ,
1358+ ).images
1359+ image_text2img = text2img (
1360+ [prompt ],
1361+ generator = generator ,
1362+ num_inference_steps = 2 ,
1363+ output_type = "np" ,
1364+ ).images
1365+
1366+ assert image_inpaint .shape == (1 , 32 , 32 , 3 )
1367+ assert image_img2img .shape == (1 , 32 , 32 , 3 )
1368+ assert image_text2img .shape == (1 , 128 , 128 , 3 )
1369+
13171370
13181371class PipelineTesterMixin (unittest .TestCase ):
13191372 def tearDown (self ):
0 commit comments