From c90c86f57c9dc13abf6c2ea30e886ef47edc56ff Mon Sep 17 00:00:00 2001
From: Aryan
Date: Sat, 3 Aug 2024 22:36:44 +0200
Subject: [PATCH 1/3] apply qk norm in attention processors

---
 src/diffusers/models/attention_processor.py | 60 ++++++++++-----------
 1 file changed, 29 insertions(+), 31 deletions(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 855085c0d933..cdb5ad6654dc 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -192,7 +192,8 @@ def __init__(
             self.norm_q = RMSNorm(dim_head, eps=eps)
             self.norm_k = RMSNorm(dim_head, eps=eps)
         else:
-            raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
+            allowed_qk_norms = {"layer_norm", "fp32_layer_norm", "layer_norm_across_heads", "rms_norm"}
+            raise ValueError(f"Unexpected value for {qk_norm=}. It must be one of {allowed_qk_norms}")
 
         if cross_attention_norm is None:
             self.norm_cross = None
@@ -661,6 +662,13 @@ def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> tor
 
         return encoder_hidden_states
 
+    def maybe_apply_qk_norm(self, query: torch.Tensor, key: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        if self.norm_q is not None:
+            query = self.norm_q(query)
+        if self.norm_k is not None:
+            key = self.norm_k(key)
+        return query, key
+
     @torch.no_grad()
     def fuse_projections(self, fuse=True):
         device = self.to_q.weight.data.device
@@ -1218,11 +1226,8 @@ def __call__(
         key = key.view(batch_size, -1, attn.heads, head_dim)
         value = value.view(batch_size, -1, attn.heads, head_dim)
 
-        # Apply QK norm.
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
+        # Apply QK-norm if needed
+        query, key = attn.maybe_apply_qk_norm(query, key)
 
         # Concatenate the projections.
         if encoder_hidden_states is not None:
@@ -1323,10 +1328,8 @@ def __call__(
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
+        # Apply QK-norm if needed
+        query, key = attn.maybe_apply_qk_norm(query, key)
 
         # Apply RoPE if needed
         if image_rotary_emb is not None:
@@ -1387,10 +1390,8 @@ def __call__(
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
+        # Apply QK-norm if needed
+        query, key = attn.maybe_apply_qk_norm(query, key)
 
         # `context` projections.
         encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
@@ -1785,6 +1786,9 @@ def __call__(
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
+        # Apply QK-norm if needed
+        query, key = attn.maybe_apply_qk_norm(query, key)
+
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states = F.scaled_dot_product_attention(
@@ -1889,10 +1893,8 @@ def __call__(
         key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
         value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
 
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
+        # Apply QK-norm if needed
+        query, key = attn.maybe_apply_qk_norm(query, key)
 
         # Apply RoPE if needed
         if rotary_emb is not None:
@@ -2003,10 +2005,8 @@ def __call__(
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
+        # Apply QK-norm if needed
+        query, key = attn.maybe_apply_qk_norm(query, key)
 
         # Apply RoPE if needed
         if image_rotary_emb is not None:
@@ -2106,10 +2106,8 @@ def __call__(
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
+        # Apply QK-norm if needed
+        query, key = attn.maybe_apply_qk_norm(query, key)
 
         # Apply RoPE if needed
         if image_rotary_emb is not None:
@@ -2185,11 +2183,8 @@ def __call__(
         # Get key-value heads
         kv_heads = inner_dim // head_dim
 
-        # Apply Query-Key Norm if needed
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
+        # Apply QK-norm if needed
+        query, key = attn.maybe_apply_qk_norm(query, key)
 
         query = query.view(batch_size, -1, attn.heads, head_dim)
 
@@ -2314,6 +2309,9 @@ def __call__(
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
+        # Apply QK-norm if needed
+        query, key = attn.maybe_apply_qk_norm(query, key)
+
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states = F.scaled_dot_product_attention(

From 0018fc4c033de2245945d5911a38e974fa979747 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Sun, 4 Aug 2024 12:42:40 +0200
Subject: [PATCH 2/3] revert attention processor

---
 src/diffusers/models/attention_processor.py | 60 +++++++++++----------
 1 file changed, 31 insertions(+), 29 deletions(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index cdb5ad6654dc..855085c0d933 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -192,8 +192,7 @@ def __init__(
             self.norm_q = RMSNorm(dim_head, eps=eps)
             self.norm_k = RMSNorm(dim_head, eps=eps)
         else:
-            allowed_qk_norms = {"layer_norm", "fp32_layer_norm", "layer_norm_across_heads", "rms_norm"}
-            raise ValueError(f"Unexpected value for {qk_norm=}. It must be one of {allowed_qk_norms}")
+            raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
 
         if cross_attention_norm is None:
             self.norm_cross = None
@@ -662,13 +661,6 @@ def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> tor
 
         return encoder_hidden_states
 
-    def maybe_apply_qk_norm(self, query: torch.Tensor, key: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        if self.norm_q is not None:
-            query = self.norm_q(query)
-        if self.norm_k is not None:
-            key = self.norm_k(key)
-        return query, key
-
     @torch.no_grad()
     def fuse_projections(self, fuse=True):
         device = self.to_q.weight.data.device
@@ -1226,8 +1218,11 @@ def __call__(
         key = key.view(batch_size, -1, attn.heads, head_dim)
         value = value.view(batch_size, -1, attn.heads, head_dim)
 
-        # Apply QK-norm if needed
-        query, key = attn.maybe_apply_qk_norm(query, key)
+        # Apply QK norm.
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
 
         # Concatenate the projections.
         if encoder_hidden_states is not None:
@@ -1328,8 +1323,10 @@ def __call__(
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
-        # Apply QK-norm if needed
-        query, key = attn.maybe_apply_qk_norm(query, key)
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
 
         # Apply RoPE if needed
         if image_rotary_emb is not None:
@@ -1390,8 +1387,10 @@ def __call__(
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
-        # Apply QK-norm if needed
-        query, key = attn.maybe_apply_qk_norm(query, key)
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
 
         # `context` projections.
         encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
@@ -1786,9 +1785,6 @@ def __call__(
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
-        # Apply QK-norm if needed
-        query, key = attn.maybe_apply_qk_norm(query, key)
-
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states = F.scaled_dot_product_attention(
@@ -1893,8 +1889,10 @@ def __call__(
         key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
         value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
 
-        # Apply QK-norm if needed
-        query, key = attn.maybe_apply_qk_norm(query, key)
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
 
         # Apply RoPE if needed
         if rotary_emb is not None:
@@ -2005,8 +2003,10 @@ def __call__(
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
-        # Apply QK-norm if needed
-        query, key = attn.maybe_apply_qk_norm(query, key)
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
 
         # Apply RoPE if needed
         if image_rotary_emb is not None:
@@ -2106,8 +2106,10 @@ def __call__(
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
-        # Apply QK-norm if needed
-        query, key = attn.maybe_apply_qk_norm(query, key)
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
 
         # Apply RoPE if needed
         if image_rotary_emb is not None:
@@ -2183,8 +2185,11 @@ def __call__(
         # Get key-value heads
         kv_heads = inner_dim // head_dim
 
-        # Apply QK-norm if needed
-        query, key = attn.maybe_apply_qk_norm(query, key)
+        # Apply Query-Key Norm if needed
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
 
         query = query.view(batch_size, -1, attn.heads, head_dim)
 
@@ -2309,9 +2314,6 @@ def __call__(
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
-        # Apply QK-norm if needed
-        query, key = attn.maybe_apply_qk_norm(query, key)
-
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states = F.scaled_dot_product_attention(

From d1d995f6f4fde670a67753cd9c7e326a4a9330d9 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Sun, 4 Aug 2024 12:44:33 +0200
Subject: [PATCH 3/3] qk-norm in only attention proc 2.0 and fused variant

---
 src/diffusers/models/attention_processor.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 855085c0d933..30be4ebadf26 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -1785,6 +1785,11 @@ def __call__(
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states = F.scaled_dot_product_attention(
@@ -2314,6 +2319,11 @@ def __call__(
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states = F.scaled_dot_product_attention(
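Note (not part of the patch series): after patch 2/3 reverts the shared `maybe_apply_qk_norm` helper and patch 3/3 re-adds the norms only in AttnProcessor2_0 and its fused variant, the end state is the inline pattern visible in the hunks above: `norm_q`/`norm_k` are applied to the per-head query and key tensors right before `F.scaled_dot_product_attention`. The snippet below is a minimal standalone sketch of that pattern for reference only; the `SimpleQKNormAttention` module, its dimensions, and the choice of `nn.LayerNorm` over the head dimension are illustrative assumptions, not the diffusers `Attention` class or its processors.

import torch
import torch.nn.functional as F
from torch import nn


class SimpleQKNormAttention(nn.Module):
    # Illustrative stand-in for an attention block with QK-norm; not the diffusers Attention class.
    def __init__(self, dim: int, heads: int, qk_norm: bool = True, eps: float = 1e-6):
        super().__init__()
        self.heads = heads
        self.head_dim = dim // heads
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        # Norms act on the last (head_dim) axis, matching the per-head application in the patch.
        self.norm_q = nn.LayerNorm(self.head_dim, eps=eps) if qk_norm else None
        self.norm_k = nn.LayerNorm(self.head_dim, eps=eps) if qk_norm else None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, _ = hidden_states.shape

        # Project and split heads: (batch, heads, seq_len, head_dim).
        query = self.to_q(hidden_states).view(batch_size, seq_len, self.heads, self.head_dim).transpose(1, 2)
        key = self.to_k(hidden_states).view(batch_size, seq_len, self.heads, self.head_dim).transpose(1, 2)
        value = self.to_v(hidden_states).view(batch_size, seq_len, self.heads, self.head_dim).transpose(1, 2)

        # Apply QK-norm if configured, after the head split and before attention.
        if self.norm_q is not None:
            query = self.norm_q(query)
        if self.norm_k is not None:
            key = self.norm_k(key)

        hidden_states = F.scaled_dot_product_attention(query, key, value)
        return hidden_states.transpose(1, 2).reshape(batch_size, seq_len, self.heads * self.head_dim)


# Example usage (requires PyTorch >= 2.0 for scaled_dot_product_attention):
# attn = SimpleQKNormAttention(dim=64, heads=8)
# out = attn(torch.randn(2, 16, 64))  # -> shape (2, 16, 64)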