@@ -103,6 +103,7 @@ def __init__(
         upcast_softmax: bool = False,
         cross_attention_norm: Optional[str] = None,
         cross_attention_norm_num_groups: int = 32,
+        qk_norm: Optional[str] = None,
         added_kv_proj_dim: Optional[int] = None,
         norm_num_groups: Optional[int] = None,
         spatial_norm_dim: Optional[int] = None,
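Usage note: a minimal sketch of opting into the new flag when constructing an attention block. Only qk_norm comes from this diff; the import path and the query_dim/heads/dim_head/eps values are illustrative assumptions.

    from diffusers.models.attention_processor import Attention

    # Hypothetical sizes; only qk_norm="layer_norm" is the new behaviour.
    attn = Attention(
        query_dim=1152,
        heads=16,
        dim_head=72,
        eps=1e-6,
        qk_norm="layer_norm",  # adds LayerNorm(dim_head) on query and key (norm_q / norm_k)
    )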
@@ -161,6 +162,15 @@ def __init__(
         else:
             self.spatial_norm = None

+        if qk_norm is None:
+            self.norm_q = None
+            self.norm_k = None
+        elif qk_norm == "layer_norm":
+            self.norm_q = nn.LayerNorm(dim_head, eps=eps)
+            self.norm_k = nn.LayerNorm(dim_head, eps=eps)
+        else:
+            raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
+
         if cross_attention_norm is None:
             self.norm_cross = None
         elif cross_attention_norm == "layer_norm":
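For intuition, a standalone sketch (plain PyTorch, hypothetical shapes) of what the "layer_norm" branch adds: a LayerNorm over the dim_head axis, applied independently to each head's query and key vectors before the attention scores are computed.

    import torch
    import torch.nn as nn

    batch, heads, seq_len, dim_head = 2, 16, 128, 72
    norm_q = nn.LayerNorm(dim_head, eps=1e-6)
    norm_k = nn.LayerNorm(dim_head, eps=1e-6)

    query = torch.randn(batch, heads, seq_len, dim_head)
    key = torch.randn(batch, heads, seq_len, dim_head)

    # Shapes are unchanged; each head_dim-sized vector is normalized on its own.
    query = norm_q(query)
    key = norm_k(key)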
@@ -1426,6 +1436,104 @@ def __call__(
         return hidden_states


+class HunyuanAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+    used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on the query and key vectors.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        from .embeddings import apply_rotary_emb
+
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None:
+            query = apply_rotary_emb(query, image_rotary_emb)
+            if not attn.is_cross_attention:
+                key = apply_rotary_emb(key, image_rotary_emb)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class FusedAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
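A rough end-to-end sketch of wiring the new processor. The import path, the sizes, and the identity (cos, sin) rotary pair are assumptions for illustration; only Attention(qk_norm=...), HunyuanAttnProcessor2_0, and the image_rotary_emb argument come from this diff.

    import torch
    from diffusers.models.attention_processor import Attention, HunyuanAttnProcessor2_0

    attn = Attention(query_dim=1152, heads=16, dim_head=72, eps=1e-6, qk_norm="layer_norm")
    attn.set_processor(HunyuanAttnProcessor2_0())

    batch, tokens, channels = 1, 32 * 32, 1152
    hidden_states = torch.randn(batch, tokens, channels)

    # The processor expects image_rotary_emb as a (cos, sin) pair; cos=1, sin=0 is an
    # identity rotation standing in for a real 2D RoPE table of shape (tokens, dim_head).
    image_rotary_emb = (torch.ones(tokens, 72), torch.zeros(tokens, 72))

    out = attn(hidden_states, image_rotary_emb=image_rotary_emb)  # (1, 1024, 1152)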