@@ -478,11 +478,22 @@ def encode_image(self, image, device, num_images_per_prompt):
478478 image = self .feature_extractor (image , return_tensors = "pt" ).pixel_values
479479
480480 image = image .to (device = device , dtype = dtype )
481- image_embeds = self .image_encoder (image ).image_embeds
482- image_embeds = image_embeds .repeat_interleave (num_images_per_prompt , dim = 0 )
481+ if output_hidden_states :
482+ image_enc_hidden_states = self .image_encoder (image , output_hidden_states = True ).hidden_states [- 2 ]
483+ image_enc_hidden_states = image_enc_hidden_states .repeat_interleave (num_images_per_prompt , dim = 0 )
484+ uncond_image_enc_hidden_states = self .image_encoder (
485+ torch .zeros_like (image ), output_hidden_states = True
486+ ).hidden_states [- 2 ]
487+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states .repeat_interleave (
488+ num_images_per_prompt , dim = 0
489+ )
490+ return image_enc_hidden_states , uncond_image_enc_hidden_states
491+ else :
492+ image_embeds = self .image_encoder (image ).image_embeds
493+ image_embeds = image_embeds .repeat_interleave (num_images_per_prompt , dim = 0 )
494+ uncond_image_embeds = torch .zeros_like (image_embeds )
483495
484- uncond_image_embeds = torch .zeros_like (image_embeds )
485- return image_embeds , uncond_image_embeds
496+ return image_embeds , uncond_image_embeds
486497
487498 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
488499 def run_safety_checker (self , image , device , dtype ):
0 commit comments