From a65dd580779f23d98a0fd3d0db3ded32d917c681 Mon Sep 17 00:00:00 2001 From: William Berman Date: Sat, 8 Apr 2023 16:50:04 -0700 Subject: [PATCH] add AttnAddedKVProcessor2_0 block --- src/diffusers/models/attention_processor.py | 78 +++++++++++++++++-- src/diffusers/models/unet_2d_blocks.py | 23 +++++- .../versatile_diffusion/modeling_text_unet.py | 13 +++- .../unclip/test_unclip_image_variation.py | 7 +- 4 files changed, 109 insertions(+), 12 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 41baf999999d..f2a5a376bf39 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -255,11 +255,15 @@ def batch_to_head_dim(self, tensor): tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) return tensor - def head_to_batch_dim(self, tensor): + def head_to_batch_dim(self, tensor, out_dim=3): head_size = self.heads batch_size, seq_len, dim = tensor.shape tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3) + + if out_dim == 3: + tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor def get_attention_scores(self, query, key, attention_mask=None): @@ -293,7 +297,7 @@ def get_attention_scores(self, query, key, attention_mask=None): return attention_probs - def prepare_attention_mask(self, attention_mask, target_length, batch_size=None): + def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3): if batch_size is None: deprecate( "batch_size=None", @@ -320,8 +324,13 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None) else: attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) - if attention_mask.shape[0] < batch_size * head_size: - attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + return attention_mask def norm_encoder_hidden_states(self, encoder_hidden_states): @@ -499,6 +508,64 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a return hidden_states +class AttnAddedKVProcessor2_0: + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
+ ) + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + residual = hidden_states + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) + batch_size, sequence_length, _ = hidden_states.shape + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + query = attn.head_to_batch_dim(query, out_dim=4) + + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4) + encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4) + + if not attn.only_cross_attention: + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + key = attn.head_to_batch_dim(key, out_dim=4) + value = attn.head_to_batch_dim(value, out_dim=4) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + else: + key = encoder_hidden_states_key_proj + value = encoder_hidden_states_value_proj + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1]) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + + return hidden_states + + class XFormersAttnProcessor: def __init__(self, attention_op: Optional[Callable] = None): self.attention_op = attention_op @@ -764,6 +831,7 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, SlicedAttnProcessor, AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, + AttnAddedKVProcessor2_0, LoRAAttnProcessor, LoRAXFormersAttnProcessor, ] diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 08578c81091e..439c5c34b601 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -15,10 +15,11 @@ import numpy as np import torch +import torch.nn.functional as F from torch import nn from .attention import AdaGroupNorm, AttentionBlock -from .attention_processor import Attention, AttnAddedKVProcessor +from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 from .dual_transformer_2d import DualTransformer2DModel from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D from .transformer_2d import Transformer2DModel @@ -612,6 +613,10 @@ def __init__( attentions = [] for _ in range(num_layers): + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + attentions.append( Attention( query_dim=in_channels, @@ 
-624,7 +629,7 @@ def __init__( upcast_softmax=True, only_cross_attention=only_cross_attention, cross_attention_norm=cross_attention_norm, - processor=AttnAddedKVProcessor(), + processor=processor, ) ) resnets.append( @@ -1396,6 +1401,11 @@ def __init__( skip_time_act=skip_time_act, ) ) + + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + attentions.append( Attention( query_dim=out_channels, @@ -1408,7 +1418,7 @@ def __init__( upcast_softmax=True, only_cross_attention=only_cross_attention, cross_attention_norm=cross_attention_norm, - processor=AttnAddedKVProcessor(), + processor=processor, ) ) self.attentions = nn.ModuleList(attentions) @@ -2399,6 +2409,11 @@ def __init__( skip_time_act=skip_time_act, ) ) + + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + attentions.append( Attention( query_dim=out_channels, @@ -2411,7 +2426,7 @@ def __init__( upcast_softmax=True, only_cross_attention=only_cross_attention, cross_attention_norm=cross_attention_norm, - processor=AttnAddedKVProcessor(), + processor=processor, ) ) self.attentions = nn.ModuleList(attentions) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 4c0a4d89dc1e..35ddfcadc3cb 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -8,7 +8,12 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...models import ModelMixin from ...models.attention import Attention -from ...models.attention_processor import AttentionProcessor, AttnAddedKVProcessor, AttnProcessor +from ...models.attention_processor import ( + AttentionProcessor, + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + AttnProcessor, +) from ...models.dual_transformer_2d import DualTransformer2DModel from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from ...models.transformer_2d import Transformer2DModel @@ -1545,6 +1550,10 @@ def __init__( attentions = [] for _ in range(num_layers): + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + attentions.append( Attention( query_dim=in_channels, @@ -1557,7 +1566,7 @@ def __init__( upcast_softmax=True, only_cross_attention=only_cross_attention, cross_attention_norm=cross_attention_norm, - processor=AttnAddedKVProcessor(), + processor=processor, ) ) resnets.append( diff --git a/tests/pipelines/unclip/test_unclip_image_variation.py b/tests/pipelines/unclip/test_unclip_image_variation.py index 304f5f286830..3cacb0bcad0b 100644 --- a/tests/pipelines/unclip/test_unclip_image_variation.py +++ b/tests/pipelines/unclip/test_unclip_image_variation.py @@ -421,7 +421,12 @@ class DummyScheduler: def test_attention_slicing_forward_pass(self): test_max_difference = torch_device == "cpu" - self._test_attention_slicing_forward_pass(test_max_difference=test_max_difference) + # Check is relaxed because there is not a torch 2.0 sliced attention added kv processor + expected_max_diff = 1e-2 + + self._test_attention_slicing_forward_pass( + test_max_difference=test_max_difference, expected_max_diff=expected_max_diff + ) # Overriding PipelineTesterMixin::test_inference_batch_single_identical # because UnCLIP undeterminism requires a looser check.
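
Usage note (not part of the patch): the snippet below is a minimal sketch of how the new processor can be picked up from user code. It reuses the same hasattr(F, "scaled_dot_product_attention") fallback that the UNet blocks above now apply internally when constructing their attention layers; the `unet` variable and the set_attn_processor call are illustrative assumptions about an already-loaded diffusers model, not something this diff adds.

# Sketch only: select the torch 2.0 added-KV processor when
# scaled_dot_product_attention is available, otherwise fall back to the
# original AttnAddedKVProcessor, mirroring the check used in the UNet blocks.
import torch.nn.functional as F

from diffusers.models.attention_processor import AttnAddedKVProcessor, AttnAddedKVProcessor2_0

processor = (
    AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
)

# Assumption: `unet` is an already-loaded model whose attention layers use the
# added key/value projections (e.g. the unCLIP decoder). Applying the processor
# to every attention module would go through the model's processor setter:
# unet.set_attn_processor(processor)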