diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 864b042c245a..41baf999999d 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -56,7 +56,8 @@ def __init__(
         bias=False,
         upcast_attention: bool = False,
         upcast_softmax: bool = False,
-        cross_attention_norm: bool = False,
+        cross_attention_norm: Optional[str] = None,
+        cross_attention_norm_num_groups: int = 32,
         added_kv_proj_dim: Optional[int] = None,
         norm_num_groups: Optional[int] = None,
         out_bias: bool = True,
@@ -69,7 +70,6 @@ def __init__(
         cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
         self.upcast_attention = upcast_attention
         self.upcast_softmax = upcast_softmax
-        self.cross_attention_norm = cross_attention_norm
 
         self.scale = dim_head**-0.5 if scale_qk else 1.0
 
@@ -92,8 +92,28 @@ def __init__(
         else:
             self.group_norm = None
 
-        if cross_attention_norm:
+        if cross_attention_norm is None:
+            self.norm_cross = None
+        elif cross_attention_norm == "layer_norm":
             self.norm_cross = nn.LayerNorm(cross_attention_dim)
+        elif cross_attention_norm == "group_norm":
+            if self.added_kv_proj_dim is not None:
+                # The given `encoder_hidden_states` are initially of shape
+                # (batch_size, seq_len, added_kv_proj_dim) before being projected
+                # to (batch_size, seq_len, cross_attention_dim). The norm is applied
+                # before the projection, so we need to use `added_kv_proj_dim` as
+                # the number of channels for the group norm.
+                norm_cross_num_channels = added_kv_proj_dim
+            else:
+                norm_cross_num_channels = cross_attention_dim
+
+            self.norm_cross = nn.GroupNorm(
+                num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
+            )
+        else:
+            raise ValueError(
+                f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
+            )
 
         self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
@@ -304,6 +324,25 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None)
             attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
         return attention_mask
 
+    def norm_encoder_hidden_states(self, encoder_hidden_states):
+        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
+
+        if isinstance(self.norm_cross, nn.LayerNorm):
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+        elif isinstance(self.norm_cross, nn.GroupNorm):
+            # Group norm norms along the channels dimension and expects
+            # input to be in the shape of (N, C, *). In this case, we want
+            # to norm along the hidden dimension, so we need to move
+            # (batch_size, sequence_length, hidden_size) ->
+            # (batch_size, hidden_size, sequence_length)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+        else:
+            assert False
+
+        return encoder_hidden_states
+
 
 class AttnProcessor:
     def __call__(
@@ -321,8 +360,8 @@ def __call__(
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
 
@@ -388,7 +427,10 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
         query = attn.head_to_batch_dim(query)
 
-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -416,6 +458,11 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
 
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
         query = attn.to_q(hidden_states)
@@ -467,8 +514,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
 
@@ -511,8 +558,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
 
@@ -561,7 +608,10 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
         query = attn.head_to_batch_dim(query).contiguous()
 
-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -598,8 +648,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
 
@@ -647,6 +697,11 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
 
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
 
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
         query = attn.to_q(hidden_states)
diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py
index 540059b10713..08578c81091e 100644
--- a/src/diffusers/models/unet_2d_blocks.py
+++ b/src/diffusers/models/unet_2d_blocks.py
@@ -44,6 +44,7 @@ def get_down_block(
     resnet_time_scale_shift="default",
     resnet_skip_time_act=False,
     resnet_out_scale_factor=1.0,
+    cross_attention_norm=None,
 ):
     down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
     if down_block_type == "DownBlock2D":
@@ -126,6 +127,7 @@ def get_down_block(
             skip_time_act=resnet_skip_time_act,
             output_scale_factor=resnet_out_scale_factor,
             only_cross_attention=only_cross_attention,
+            cross_attention_norm=cross_attention_norm,
         )
     elif down_block_type == "SkipDownBlock2D":
         return SkipDownBlock2D(
@@ -223,6 +225,7 @@ def get_up_block(
     resnet_time_scale_shift="default",
     resnet_skip_time_act=False,
     resnet_out_scale_factor=1.0,
+    cross_attention_norm=None,
 ):
     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
     if up_block_type == "UpBlock2D":
@@ -293,6 +296,7 @@ def get_up_block(
             skip_time_act=resnet_skip_time_act,
             output_scale_factor=resnet_out_scale_factor,
             only_cross_attention=only_cross_attention,
+            cross_attention_norm=cross_attention_norm,
         )
     elif up_block_type == "AttnUpBlock2D":
         return AttnUpBlock2D(
@@ -578,6 +582,7 @@ def __init__(
         cross_attention_dim=1280,
         skip_time_act=False,
         only_cross_attention=False,
+        cross_attention_norm=None,
     ):
         super().__init__()
 
@@ -618,6 +623,7 @@ def __init__(
                     bias=True,
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
+                    cross_attention_norm=cross_attention_norm,
                     processor=AttnAddedKVProcessor(),
                 )
             )
@@ -1361,6 +1367,7 @@ def __init__(
         add_downsample=True,
         skip_time_act=False,
         only_cross_attention=False,
+        cross_attention_norm=None,
     ):
         super().__init__()
 
@@ -1400,6 +1407,7 @@ def __init__(
                     bias=True,
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
+                    cross_attention_norm=cross_attention_norm,
                     processor=AttnAddedKVProcessor(),
                 )
             )
@@ -1580,7 +1588,7 @@ def __init__(
                     temb_channels=temb_channels,
                     attention_bias=True,
                     add_self_attention=add_self_attention,
-                    cross_attention_norm=True,
+                    cross_attention_norm="layer_norm",
                     group_size=resnet_group_size,
                 )
             )
@@ -2361,6 +2369,7 @@ def __init__(
         add_upsample=True,
         skip_time_act=False,
         only_cross_attention=False,
+        cross_attention_norm=None,
     ):
         super().__init__()
         resnets = []
@@ -2401,6 +2410,7 @@ def __init__(
                     bias=True,
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
+                    cross_attention_norm=cross_attention_norm,
                     processor=AttnAddedKVProcessor(),
                 )
             )
@@ -2608,7 +2618,7 @@ def __init__(
                     temb_channels=temb_channels,
                     attention_bias=True,
                     add_self_attention=add_self_attention,
-                    cross_attention_norm=True,
+                    cross_attention_norm="layer_norm",
                     upcast_attention=upcast_attention,
                 )
             )
@@ -2703,7 +2713,7 @@ def __init__(
         upcast_attention: bool = False,
         temb_channels: int = 768,  # for ada_group_norm
         add_self_attention: bool = False,
-        cross_attention_norm: bool = False,
+        cross_attention_norm: Optional[str] = None,
         group_size: int = 32,
     ):
         super().__init__()
@@ -2719,7 +2729,7 @@ def __init__(
             dropout=dropout,
             bias=attention_bias,
             cross_attention_dim=None,
-            cross_attention_norm=False,
+            cross_attention_norm=None,
         )
 
         # 2. Cross-Attn
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 3fb4202ed119..5134e3dae16e 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -164,6 +164,7 @@ def __init__(
         projection_class_embeddings_input_dim: Optional[int] = None,
         class_embeddings_concat: bool = False,
         mid_block_only_cross_attention: Optional[bool] = None,
+        cross_attention_norm: Optional[str] = None,
     ):
         super().__init__()
 
@@ -323,6 +324,7 @@ def __init__(
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 resnet_skip_time_act=resnet_skip_time_act,
                 resnet_out_scale_factor=resnet_out_scale_factor,
+                cross_attention_norm=cross_attention_norm,
             )
             self.down_blocks.append(down_block)
 
@@ -355,6 +357,7 @@ def __init__(
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 skip_time_act=resnet_skip_time_act,
                 only_cross_attention=mid_block_only_cross_attention,
+                cross_attention_norm=cross_attention_norm,
             )
         elif mid_block_type is None:
             self.mid_block = None
@@ -406,6 +409,7 @@ def __init__(
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 resnet_skip_time_act=resnet_skip_time_act,
                 resnet_out_scale_factor=resnet_out_scale_factor,
+                cross_attention_norm=cross_attention_norm,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
index e457ad2b3afc..0239c8128171 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
@@ -243,8 +243,8 @@ def __call__(
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
 
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
index 063882284754..c6d67c6148d2 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
@@ -65,8 +65,8 @@ def __call__(
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
 
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index 51d1c62c926b..acf7ed3ff5f8 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -250,6 +250,7 @@ def __init__(
         projection_class_embeddings_input_dim: Optional[int] = None,
         class_embeddings_concat: bool = False,
         mid_block_only_cross_attention: Optional[bool] = None,
+        cross_attention_norm: Optional[str] = None,
     ):
         super().__init__()
 
@@ -415,6 +416,7 @@ def __init__(
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 resnet_skip_time_act=resnet_skip_time_act,
                 resnet_out_scale_factor=resnet_out_scale_factor,
+                cross_attention_norm=cross_attention_norm,
             )
             self.down_blocks.append(down_block)
 
@@ -447,6 +449,7 @@ def __init__(
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 skip_time_act=resnet_skip_time_act,
                 only_cross_attention=mid_block_only_cross_attention,
+                cross_attention_norm=cross_attention_norm,
             )
         elif mid_block_type is None:
             self.mid_block = None
@@ -498,6 +501,7 @@ def __init__(
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 resnet_skip_time_act=resnet_skip_time_act,
                 resnet_out_scale_factor=resnet_out_scale_factor,
+                cross_attention_norm=cross_attention_norm,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -1490,6 +1494,7 @@ def __init__(
         cross_attention_dim=1280,
         skip_time_act=False,
         only_cross_attention=False,
+        cross_attention_norm=None,
     ):
         super().__init__()
 
@@ -1530,6 +1535,7 @@ def __init__(
                     bias=True,
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
+                    cross_attention_norm=cross_attention_norm,
                     processor=AttnAddedKVProcessor(),
                 )
             )
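
Usage sketch (illustrative, not part of the patch): a minimal example of the string-valued `cross_attention_norm` option introduced above, assuming a diffusers build that includes this change. The dimensions are made up for the example; `cross_attention_norm_num_groups` must divide the normalized channel count (`cross_attention_dim` here, or `added_kv_proj_dim` when the added KV projection is used, since the norm runs before that projection).

import torch
from diffusers.models.attention_processor import Attention

# Hypothetical dimensions chosen only to illustrate the new arguments.
attn = Attention(
    query_dim=320,
    cross_attention_dim=1024,
    heads=8,
    dim_head=40,
    cross_attention_norm="group_norm",   # new string option; was a bool before this patch
    cross_attention_norm_num_groups=32,  # 32 groups over 1024 channels
)

hidden_states = torch.randn(2, 64, 320)           # (batch, query_len, query_dim)
encoder_hidden_states = torch.randn(2, 77, 1024)  # (batch, kv_len, cross_attention_dim)

# norm_encoder_hidden_states transposes to (batch, hidden, seq) for the GroupNorm
# and back again before the key/value projections.
out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.shape)  # torch.Size([2, 64, 320])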