From e35de4b707e03f5ba2511ed979f06acc51bf1e63 Mon Sep 17 00:00:00 2001 From: William Berman Date: Sat, 8 Apr 2023 13:21:00 -0700 Subject: [PATCH] add group norm type to attention processor cross attention norm This lets the cross attention norm use both a group norm block and a layer norm block. The group norm operates along the channels dimension and requires input shape (batch size, channels, *) where as the layer norm with a single `normalized_shape` dimension only operates over the least significant dimension i.e. (*, channels). The channels we want to normalize are the hidden dimension of the encoder hidden states. By convention, the encoder hidden states are always passed as (batch size, sequence length, hidden states). This means the layer norm can operate on the tensor without modification, but the group norm requires flipping the last two dimensions to operate on (batch size, hidden states, sequence length). All existing attention processors will have the same logic and we can consolidate it in a helper function `prepare_encoder_hidden_states` prepare_encoder_hidden_states -> norm_encoder_hidden_states re: @patrickvonplaten move norm_cross defined check to outside norm_encoder_hidden_states add missing attn.norm_cross check --- src/diffusers/models/attention_processor.py | 81 ++++++++++++++++--- src/diffusers/models/unet_2d_blocks.py | 18 ++++- src/diffusers/models/unet_2d_condition.py | 4 + .../pipeline_stable_diffusion_pix2pix_zero.py | 4 +- .../pipeline_stable_diffusion_sag.py | 4 +- .../versatile_diffusion/modeling_text_unet.py | 6 ++ 6 files changed, 96 insertions(+), 21 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 864b042c245a..41baf999999d 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -56,7 +56,8 @@ def __init__( bias=False, upcast_attention: bool = False, upcast_softmax: bool = False, - cross_attention_norm: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, added_kv_proj_dim: Optional[int] = None, norm_num_groups: Optional[int] = None, out_bias: bool = True, @@ -69,7 +70,6 @@ def __init__( cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim self.upcast_attention = upcast_attention self.upcast_softmax = upcast_softmax - self.cross_attention_norm = cross_attention_norm self.scale = dim_head**-0.5 if scale_qk else 1.0 @@ -92,8 +92,28 @@ def __init__( else: self.group_norm = None - if cross_attention_norm: + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": self.norm_cross = nn.LayerNorm(cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" + ) self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) @@ -304,6 +324,25 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None) attention_mask = attention_mask.repeat_interleave(head_size, dim=0) return attention_mask + def norm_encoder_hidden_states(self, encoder_hidden_states): + assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + # Group norm norms along the channels dimension and expects + # input to be in the shape of (N, C, *). In this case, we want + # to norm along the hidden dimension, so we need to move + # (batch_size, sequence_length, hidden_size) -> + # (batch_size, hidden_size, sequence_length) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + assert False + + return encoder_hidden_states + class AttnProcessor: def __call__( @@ -321,8 +360,8 @@ def __call__( if encoder_hidden_states is None: encoder_hidden_states = hidden_states - elif attn.cross_attention_norm: - encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -388,7 +427,10 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) query = attn.head_to_batch_dim(query) - encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) @@ -416,6 +458,11 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) @@ -467,8 +514,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a if encoder_hidden_states is None: encoder_hidden_states = hidden_states - elif attn.cross_attention_norm: - encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -511,8 +558,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a if encoder_hidden_states is None: encoder_hidden_states = hidden_states - elif attn.cross_attention_norm: - encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -561,7 +608,10 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) query = attn.head_to_batch_dim(query).contiguous() - encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) @@ -598,8 +648,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a if encoder_hidden_states is None: encoder_hidden_states = hidden_states - elif attn.cross_attention_norm: - encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -647,6 +697,11 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 540059b10713..08578c81091e 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -44,6 +44,7 @@ def get_down_block( resnet_time_scale_shift="default", resnet_skip_time_act=False, resnet_out_scale_factor=1.0, + cross_attention_norm=None, ): down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type if down_block_type == "DownBlock2D": @@ -126,6 +127,7 @@ def get_down_block( skip_time_act=resnet_skip_time_act, output_scale_factor=resnet_out_scale_factor, only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, ) elif down_block_type == "SkipDownBlock2D": return SkipDownBlock2D( @@ -223,6 +225,7 @@ def get_up_block( resnet_time_scale_shift="default", resnet_skip_time_act=False, resnet_out_scale_factor=1.0, + cross_attention_norm=None, ): up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type if up_block_type == "UpBlock2D": @@ -293,6 +296,7 @@ def get_up_block( skip_time_act=resnet_skip_time_act, output_scale_factor=resnet_out_scale_factor, only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, ) elif up_block_type == "AttnUpBlock2D": return AttnUpBlock2D( @@ -578,6 +582,7 @@ def __init__( cross_attention_dim=1280, skip_time_act=False, only_cross_attention=False, + cross_attention_norm=None, ): super().__init__() @@ -618,6 +623,7 @@ def __init__( bias=True, upcast_softmax=True, only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, processor=AttnAddedKVProcessor(), ) ) @@ -1361,6 +1367,7 @@ def __init__( add_downsample=True, skip_time_act=False, only_cross_attention=False, + cross_attention_norm=None, ): super().__init__() @@ -1400,6 +1407,7 @@ def __init__( bias=True, upcast_softmax=True, only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, processor=AttnAddedKVProcessor(), ) ) @@ -1580,7 +1588,7 @@ def __init__( temb_channels=temb_channels, attention_bias=True, add_self_attention=add_self_attention, - cross_attention_norm=True, + cross_attention_norm="layer_norm", group_size=resnet_group_size, ) ) @@ -2361,6 +2369,7 @@ def __init__( add_upsample=True, skip_time_act=False, only_cross_attention=False, + cross_attention_norm=None, ): super().__init__() resnets = [] @@ -2401,6 +2410,7 @@ def __init__( bias=True, upcast_softmax=True, only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, processor=AttnAddedKVProcessor(), ) ) @@ -2608,7 +2618,7 @@ def __init__( temb_channels=temb_channels, attention_bias=True, add_self_attention=add_self_attention, - cross_attention_norm=True, + cross_attention_norm="layer_norm", upcast_attention=upcast_attention, ) ) @@ -2703,7 +2713,7 @@ def __init__( upcast_attention: bool = False, temb_channels: int = 768, # for ada_group_norm add_self_attention: bool = False, - cross_attention_norm: bool = False, + cross_attention_norm: Optional[str] = None, group_size: int = 32, ): super().__init__() @@ -2719,7 +2729,7 @@ def __init__( dropout=dropout, bias=attention_bias, cross_attention_dim=None, - cross_attention_norm=False, + cross_attention_norm=None, ) # 2. Cross-Attn diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 3fb4202ed119..5134e3dae16e 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -164,6 +164,7 @@ def __init__( projection_class_embeddings_input_dim: Optional[int] = None, class_embeddings_concat: bool = False, mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, ): super().__init__() @@ -323,6 +324,7 @@ def __init__( resnet_time_scale_shift=resnet_time_scale_shift, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, ) self.down_blocks.append(down_block) @@ -355,6 +357,7 @@ def __init__( resnet_time_scale_shift=resnet_time_scale_shift, skip_time_act=resnet_skip_time_act, only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, ) elif mid_block_type is None: self.mid_block = None @@ -406,6 +409,7 @@ def __init__( resnet_time_scale_shift=resnet_time_scale_shift, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, ) self.up_blocks.append(up_block) prev_output_channel = output_channel diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index e457ad2b3afc..0239c8128171 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -243,8 +243,8 @@ def __call__( if encoder_hidden_states is None: encoder_hidden_states = hidden_states - elif attn.cross_attention_norm: - encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py index 063882284754..c6d67c6148d2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py @@ -65,8 +65,8 @@ def __call__( if encoder_hidden_states is None: encoder_hidden_states = hidden_states - elif attn.cross_attention_norm: - encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 51d1c62c926b..acf7ed3ff5f8 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -250,6 +250,7 @@ def __init__( projection_class_embeddings_input_dim: Optional[int] = None, class_embeddings_concat: bool = False, mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, ): super().__init__() @@ -415,6 +416,7 @@ def __init__( resnet_time_scale_shift=resnet_time_scale_shift, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, ) self.down_blocks.append(down_block) @@ -447,6 +449,7 @@ def __init__( resnet_time_scale_shift=resnet_time_scale_shift, skip_time_act=resnet_skip_time_act, only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, ) elif mid_block_type is None: self.mid_block = None @@ -498,6 +501,7 @@ def __init__( resnet_time_scale_shift=resnet_time_scale_shift, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, ) self.up_blocks.append(up_block) prev_output_channel = output_channel @@ -1490,6 +1494,7 @@ def __init__( cross_attention_dim=1280, skip_time_act=False, only_cross_attention=False, + cross_attention_norm=None, ): super().__init__() @@ -1530,6 +1535,7 @@ def __init__( bias=True, upcast_softmax=True, only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, processor=AttnAddedKVProcessor(), ) )