diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 30026cd89ff9..a0fb3df9a5cd 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -400,7 +400,6 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         residual = hidden_states
         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
         batch_size, sequence_length, _ = hidden_states.shape
-        encoder_hidden_states = encoder_hidden_states.transpose(1, 2)

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

@@ -627,7 +626,6 @@ def __init__(self, slice_size):
     def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None):
         residual = hidden_states
         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
-        encoder_hidden_states = encoder_hidden_states.transpose(1, 2)

         batch_size, sequence_length, _ = hidden_states.shape

diff --git a/src/diffusers/pipelines/unclip/text_proj.py b/src/diffusers/pipelines/unclip/text_proj.py
index 0a54c3319f28..0414559500c1 100644
--- a/src/diffusers/pipelines/unclip/text_proj.py
+++ b/src/diffusers/pipelines/unclip/text_proj.py
@@ -77,10 +77,10 @@ def forward(self, *, image_embeddings, prompt_embeds, text_encoder_hidden_states
         # extra tokens of context that are concatenated to the sequence of outputs from the GLIDE text encoder"
         clip_extra_context_tokens = self.clip_extra_context_tokens_proj(image_embeddings)
         clip_extra_context_tokens = clip_extra_context_tokens.reshape(batch_size, -1, self.clip_extra_context_tokens)
+        clip_extra_context_tokens = clip_extra_context_tokens.permute(0, 2, 1)

         text_encoder_hidden_states = self.encoder_hidden_states_proj(text_encoder_hidden_states)
         text_encoder_hidden_states = self.text_encoder_hidden_states_norm(text_encoder_hidden_states)
-        text_encoder_hidden_states = text_encoder_hidden_states.permute(0, 2, 1)
-        text_encoder_hidden_states = torch.cat([clip_extra_context_tokens, text_encoder_hidden_states], dim=2)
+        text_encoder_hidden_states = torch.cat([clip_extra_context_tokens, text_encoder_hidden_states], dim=1)

         return text_encoder_hidden_states, additive_clip_time_embeddings

diff --git a/tests/pipelines/unclip/test_unclip_image_variation.py b/tests/pipelines/unclip/test_unclip_image_variation.py
index ff32ac5f9aaf..304f5f286830 100644
--- a/tests/pipelines/unclip/test_unclip_image_variation.py
+++ b/tests/pipelines/unclip/test_unclip_image_variation.py
@@ -54,6 +54,7 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCa
         "decoder_num_inference_steps",
         "super_res_num_inference_steps",
     ]
+    test_xformers_attention = False

     @property
     def text_embedder_hidden_size(self):
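
For context, here is a minimal shape sketch of the text_proj.py change (the dimension sizes below are illustrative assumptions, not values from this PR): after the fix, the projected image embeddings are permuted into the (batch, seq, channels) layout that the attention processors already expect, and the extra context tokens are concatenated along the sequence dimension (dim=1). That is why the encoder_hidden_states transpose can be dropped from the attention processors above.

# Shape sketch of the unCLIP text_proj fix; all sizes are hypothetical.
import torch

batch_size, cross_attention_dim = 2, 32  # assumed, for illustration
num_extra_tokens, text_seq_len = 4, 77   # assumed, for illustration

# The projected image embeddings come out channel-first after the reshape ...
clip_extra_context_tokens = torch.randn(batch_size, cross_attention_dim, num_extra_tokens)
# ... so permute them into the (batch, seq, channels) layout used elsewhere.
clip_extra_context_tokens = clip_extra_context_tokens.permute(0, 2, 1)  # (2, 4, 32)

# Text encoder hidden states are already (batch, seq, channels).
text_encoder_hidden_states = torch.randn(batch_size, text_seq_len, cross_attention_dim)

# Concatenate along the sequence dimension (dim=1), prepending the extra tokens,
# so downstream cross-attention needs no transpose of encoder_hidden_states.
out = torch.cat([clip_extra_context_tokens, text_encoder_hidden_states], dim=1)
assert out.shape == (batch_size, num_extra_tokens + text_seq_len, cross_attention_dim)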