Skip to content

Commit 6002126

Browse files
committed
fix-copies
1 parent 3c70b1e commit 6002126

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -478,11 +478,22 @@ def encode_image(self, image, device, num_images_per_prompt):
478478
image = self.feature_extractor(image, return_tensors="pt").pixel_values
479479

480480
image = image.to(device=device, dtype=dtype)
481-
image_embeds = self.image_encoder(image).image_embeds
482-
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
481+
if output_hidden_states:
482+
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
483+
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
484+
uncond_image_enc_hidden_states = self.image_encoder(
485+
torch.zeros_like(image), output_hidden_states=True
486+
).hidden_states[-2]
487+
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
488+
num_images_per_prompt, dim=0
489+
)
490+
return image_enc_hidden_states, uncond_image_enc_hidden_states
491+
else:
492+
image_embeds = self.image_encoder(image).image_embeds
493+
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
494+
uncond_image_embeds = torch.zeros_like(image_embeds)
483495

484-
uncond_image_embeds = torch.zeros_like(image_embeds)
485-
return image_embeds, uncond_image_embeds
496+
return image_embeds, uncond_image_embeds
486497

487498
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
488499
def run_safety_checker(self, image, device, dtype):

0 commit comments

Comments
 (0)