@@ -102,93 +102,6 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:

 from typing import Tuple, Union, Optional

-def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False):
-    """
-    Reshape frequency tensor for broadcasting it with another tensor.
-
-    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
-    for the purpose of broadcasting the frequency tensor during element-wise operations.
-
-    Args:
-        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
-        x (torch.Tensor): Target tensor for broadcasting compatibility.
-        head_first (bool): head dimension first (except batch dim) or not.
-
-    Returns:
-        torch.Tensor: Reshaped frequency tensor.
-
-    Raises:
-        AssertionError: If the frequency tensor doesn't match the expected shape.
-        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
-    """
-    ndim = x.ndim
-    assert 0 <= 1 < ndim
-
-    if isinstance(freqs_cis, tuple):
-        # freqs_cis: (cos, sin) in real space
-        if head_first:
-            assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
-            shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
-        else:
-            assert freqs_cis[0].shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
-            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
-        return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
-    else:
-        # freqs_cis: values in complex space
-        if head_first:
-            assert freqs_cis.shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
-            shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
-        else:
-            assert freqs_cis.shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
-            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
-        return freqs_cis.view(*shape)
-
-
-def rotate_half(x):
-    x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
-    return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
-
-def apply_rotary_emb(
-    xq: torch.Tensor,
-    xk: Optional[torch.Tensor],
-    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
-    head_first: bool = False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """
-    Apply rotary embeddings to input tensors using the given frequency tensor.
-
-    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
-    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
-    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
-    returned as real tensors.
-
-    Args:
-        xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
-        xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
-        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Precomputed frequency tensor for complex exponentials.
-        head_first (bool): head dimension first (except batch dim) or not.
-
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
-
-    """
-    xk_out = None
-    if isinstance(freqs_cis, tuple):
-        cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)  # [S, D]
-        cos, sin = cos.to(xq.device), sin.to(xq.device)
-        xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
-        if xk is not None:
-            xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
-    else:
-        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # [B, S, H, D//2]
-        freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device)  # [S, D//2] --> [1, S, 1, D//2]
-        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
-        if xk is not None:
-            xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  # [B, S, H, D//2]
-            xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
-
-    return xq_out, xk_out
-
 class HunyuanDiTAttentionPool(nn.Module):
     def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
         super().__init__()
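The removed helpers above apply RoPE either from a complex-valued freqs_cis tensor or from a real (cos, sin) pair. As a reference for what the real-valued path computes, here is a minimal, self-contained sketch; the shapes, the repeat_interleave layout of the cos/sin tables, and the helper names are illustrative assumptions, not part of the diffusers API:

import torch

def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Treat adjacent pairs in the last dim as (real, imag) and rotate by 90 degrees,
    # i.e. (a, b) -> (-b, a), mirroring the removed rotate_half above.
    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
    return torch.stack([-x_imag, x_real], dim=-1).flatten(-2)

def _apply_rope_real(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [B, S, H, D]; cos/sin: [S, D], broadcast over batch and heads (the head_first=False case).
    cos, sin = cos[None, :, None, :], sin[None, :, None, :]
    return (x.float() * cos + _rotate_half(x.float()) * sin).type_as(x)

# Illustrative shapes only.
q = torch.randn(2, 16, 8, 64)                    # [B, S, H, D]
theta = torch.randn(16, 32)                      # [S, D//2] rotation angles (assumed layout)
cos = theta.cos().repeat_interleave(2, dim=-1)   # [S, D]
sin = theta.sin().repeat_interleave(2, dim=-1)   # [S, D]
assert _apply_rope_real(q, cos, sin).shape == q.shape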
@@ -223,147 +136,13 @@ def forward(self, x):
             need_weights=False
         )
         return x.squeeze(0)
-
-class HunyuanDiTCrossAttention(nn.Module):
-    """
-    Use QK Normalization.
-    """
-    def __init__(self,
-                 qdim,
-                 kdim,
-                 num_heads,
-                 qkv_bias=True,
-                 qk_norm=False,
-                 attn_drop=0.0,
-                 proj_drop=0.0,
-                 device=None,
-                 dtype=None,
-                 norm_layer=nn.LayerNorm,
-                 ):
-        factory_kwargs = {'device': device, 'dtype': dtype}
-        super().__init__()
-        self.qdim = qdim
-        self.kdim = kdim
-        self.num_heads = num_heads
-        assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
-        self.head_dim = self.qdim // num_heads
-        assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
-        self.scale = self.head_dim ** -0.5
-
-        self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
-        self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)
-
-        # TODO: eps should be 1 / 65530 if using fp16
-        self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
-        self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
-        self.attn_drop = nn.Dropout(attn_drop)
-        self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
-        self.proj_drop = nn.Dropout(proj_drop)
-
-    def forward(self, x, y, freqs_cis_img=None):
-        """
-        Parameters
-        ----------
-        x: torch.Tensor
-            (batch, seqlen1, hidden_dim) (where hidden_dim = num heads * head dim)
-        y: torch.Tensor
-            (batch, seqlen2, hidden_dim2)
-        freqs_cis_img: torch.Tensor
-            (batch, hidden_dim // 2), RoPE for image
-        """
-        b, s1, c = x.shape   # [b, s1, D]
-        _, s2, c = y.shape   # [b, s2, 1024]
-
-        q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim)       # [b, s1, h, d]
-        kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim)  # [b, s2, 2, h, d]
-        k, v = kv.unbind(dim=2)                                             # [b, s, h, d]
-        q = self.q_norm(q)
-        k = self.k_norm(k)
-
-        # Apply RoPE if needed
-        if freqs_cis_img is not None:
-            qq, _ = apply_rotary_emb(q, None, freqs_cis_img)
-            assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}'
-            q = qq
-
-        q = q * self.scale
-        q = q.transpose(-2, -3).contiguous()    # q ->  B, L1, H, C - B, H, L1, C
-        k = k.permute(0, 2, 3, 1).contiguous()  # k ->  B, L2, H, C - B, H, C, L2
-        attn = q @ k                            # attn -> B, H, L1, L2
-        attn = attn.softmax(dim=-1)             # attn -> B, H, L1, L2
-        attn = self.attn_drop(attn)
-        x = attn @ v.transpose(-2, -3)          # v -> B, L2, H, C - B, H, L2, C    x -> B, H, L1, C
-        context = x.transpose(1, 2)             # context -> B, H, L1, C - B, L1, H, C
-
-        context = context.contiguous().view(b, s1, -1)
-
-        out = self.out_proj(context)  # context.reshape - B, L1, -1
-        out = self.proj_drop(out)
-
-        out_tuple = (out,)
-
-        return out_tuple
-
-
-class HunyuanDiTAttention(nn.Module):
-    """
-    We rename some layer names to align with flash attention
-    """
-    def __init__(self, dim, num_heads, qkv_bias=True, qk_norm=False, attn_drop=0., proj_drop=0.,
-                 norm_layer=nn.LayerNorm,
-                 ):
-        super().__init__()
-        self.dim = dim
-        self.num_heads = num_heads
-        assert self.dim % num_heads == 0, 'dim should be divisible by num_heads'
-        self.head_dim = self.dim // num_heads
-        # This assertion is aligned with flash attention
-        assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
-        self.scale = self.head_dim ** -0.5
-
-        # qkv --> Wqkv
-        self.Wqkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-        # TODO: eps should be 1 / 65530 if using fp16
-        self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
-        self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
-        self.attn_drop = nn.Dropout(attn_drop)
-        self.out_proj = nn.Linear(dim, dim)
-        self.proj_drop = nn.Dropout(proj_drop)
-
-    def forward(self, x, freqs_cis_img=None):
-        B, N, C = x.shape
-        qkv = self.Wqkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)  # [3, b, h, s, d]
-        q, k, v = qkv.unbind(0)  # [b, h, s, d]
-        q = self.q_norm(q)       # [b, h, s, d]
-        k = self.k_norm(k)       # [b, h, s, d]
-
-        # Apply RoPE if needed
-        if freqs_cis_img is not None:
-            qq, kk = apply_rotary_emb(q, k, freqs_cis_img, head_first=True)
-            assert qq.shape == q.shape and kk.shape == k.shape, \
-                f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}'
-            q, k = qq, kk
-
-        q = q * self.scale
-        attn = q @ k.transpose(-2, -1)  # [b, h, s, d] @ [b, h, d, s]
-        attn = attn.softmax(dim=-1)     # [b, h, s, s]
-        attn = self.attn_drop(attn)
-        x = attn @ v                    # [b, h, s, d]
-
-        x = x.transpose(1, 2).reshape(B, N, C)  # [b, s, h, d]
-        x = self.out_proj(x)
-        x = self.proj_drop(x)
-
-        out_tuple = (x,)
-
-        return out_tuple
 ### ==== end ====

+
 @maybe_allow_in_graph
 class HunyuanDiTBlock(nn.Module):
     r"""
     HunyuanDiT Transformer block. Allow skip connection and QKNorm
-
     Parameters:
         dim (`int`): The number of channels in the input and output.
         num_attention_heads (`int`): The number of heads to use for multi-head attention.
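Both removed modules spell out the same attention arithmetic by hand: scale the (optionally QK-normalized) queries, form q @ k^T, softmax, dropout, then multiply by v. The "2_0" suffix of the replacement processor suggests PyTorch's fused scaled_dot_product_attention path; below is a rough equivalence check for the zero-dropout case, assuming [B, H, S, D]-shaped inputs (the shapes are made up for illustration):

import torch
import torch.nn.functional as F

def manual_attention(q, k, v):
    # What the removed forward() methods compute explicitly.
    scale = q.shape[-1] ** -0.5
    attn = (q * scale) @ k.transpose(-2, -1)   # [B, H, S1, S2]
    attn = attn.softmax(dim=-1)
    return attn @ v                            # [B, H, S1, D]

q = torch.randn(2, 8, 16, 64)   # [B, H, S1, D]
k = torch.randn(2, 8, 24, 64)   # [B, H, S2, D]
v = torch.randn(2, 8, 24, 64)

fused = F.scaled_dot_product_attention(q, k, v)   # PyTorch >= 2.0 fused kernel
torch.testing.assert_close(manual_attention(q, k, v), fused, rtol=1e-4, atol=1e-4)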
@@ -416,19 +195,36 @@ def __init__(
         # 1. Self-Attn
         self.norm1 = FP32_Layernorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

-        self.attn1 = HunyuanDiTAttention(dim, num_heads=num_attention_heads, qkv_bias=True, qk_norm=qk_norm)
+        from .attention_processor import HunyuanAttnProcessor2_0
+        self.attn1 = Attention(
+            query_dim=dim,
+            cross_attention_dim=dim,
+            dim_head=dim // num_attention_heads,
+            heads=num_attention_heads,
+            qk_norm="layer_norm" if qk_norm else None,
+            eps=1e-6,
+            bias=True,
+            processor=HunyuanAttnProcessor2_0(),
+        )

         # 2. Cross-Attn
         self.norm3 = FP32_Layernorm(dim, norm_eps, norm_elementwise_affine)

-        self.attn2 = HunyuanDiTCrossAttention(dim, text_dim, num_heads=num_attention_heads, qkv_bias=True, qk_norm=qk_norm)
-
+        self.attn2 = Attention(
+            query_dim=dim,
+            cross_attention_dim=text_dim,
+            dim_head=dim // num_attention_heads,
+            heads=num_attention_heads,
+            qk_norm="layer_norm" if qk_norm else None,
+            eps=1e-6,
+            bias=True,
+            processor=HunyuanAttnProcessor2_0(),
+        )
         # 3. Feed-forward
         self.norm2 = FP32_Layernorm(dim, norm_eps, norm_elementwise_affine)

-        ### NOTE: do not switch norm2 and norm3, otherwise will load wrong key when using pretrained model!
+        ### TODO: switch norm2 and norm3 in the state dict

-        #print('mlp hidden dim:', ff_inner_dim)
         self.ff = FeedForward(
             dim,
             dropout=dropout,  ### 0.0
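Both new layers pass qk_norm="layer_norm", which takes over the role of the removed modules' per-head q_norm/k_norm: queries and keys are layer-normalized over the head dimension before the dot product, which helps keep the attention logits in a stable range. A minimal sketch of that step in isolation, with made-up dimensions:

import torch
import torch.nn as nn

head_dim = 88                               # hypothetical, e.g. dim=1408 with 16 heads
q_norm = nn.LayerNorm(head_dim, elementwise_affine=True, eps=1e-6)
k_norm = nn.LayerNorm(head_dim, elementwise_affine=True, eps=1e-6)

q = torch.randn(2, 16, 128, head_dim)       # [B, H, S_img, head_dim]
k = torch.randn(2, 16, 77, head_dim)        # [B, H, S_text, head_dim]

# LayerNorm acts on the last dimension, so every head's query/key vector is
# normalized independently before the attention scores q @ k^T are formed.
q, k = q_norm(q), k_norm(k)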
@@ -475,28 +271,27 @@ def forward(
             cat = torch.cat([hidden_states, skip], dim=-1)
             cat = self.skip_norm(cat)
             hidden_states = self.skip_linear(cat)
-
-            #print('x:', hidden_states[0])
+
         # 1. Self-Attention
         norm_hidden_states = self.norm1(hidden_states)  ### checked: self.norm1 is correct
         shift_msa = self.default_modulation(timestep).unsqueeze(dim=1)
-        attn_inputs = (norm_hidden_states + shift_msa, freq_cis_img,)
-        attn_output = self.attn1(*attn_inputs)[0]
+        attn_output = self.attn1(
+            norm_hidden_states + shift_msa,
+            temb=freq_cis_img,
+        )
         hidden_states = hidden_states + attn_output
-        #print('x:', hidden_states[0])

         # 2. Cross-Attention
-        cross_inputs = (
-            self.norm3(hidden_states), encoder_hidden_states, freq_cis_img
+        hidden_states = hidden_states + self.attn2(
+            self.norm3(hidden_states),
+            encoder_hidden_states=encoder_hidden_states,
+            temb=freq_cis_img,
         )
-        hidden_states = hidden_states + self.attn2(*cross_inputs)[0]
-        #print('x:', hidden_states[0])

-        # FFN Layer ### NOTE: do not switch norm2 and norm3, otherwise will load wrong key when using pretrained model!
+        # FFN Layer ### TODO: switch norm2 and norm3 in the state dict
         mlp_inputs = self.norm2(hidden_states)
         hidden_states = hidden_states + self.ff(mlp_inputs)
-        #print('x:', hidden_states[0])
-
+
         return hidden_states

 @maybe_allow_in_graph
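For orientation, the forward pass in the hunk above composes: optional skip fusion, a pre-norm self-attention whose input is shifted by a timestep modulation, a pre-norm cross-attention against the text states, and a pre-norm feed-forward, each followed by a residual add. A structural sketch with stand-in linear layers instead of the real attention and FFN modules (a simplification, not the actual HunyuanDiTBlock):

import torch
import torch.nn as nn

class TinyHunyuanStyleBlock(nn.Module):
    def __init__(self, dim: int, skip: bool = False):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)    # before self-attention
        self.norm3 = nn.LayerNorm(dim)    # before cross-attention
        self.norm2 = nn.LayerNorm(dim)    # before the FFN (naming kept to match the checkpoint keys)
        self.default_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim))
        self.attn1 = nn.Linear(dim, dim)  # stand-in for self-attention
        self.attn2 = nn.Linear(dim, dim)  # stand-in for cross-attention (ignores the text states here)
        self.ff = nn.Linear(dim, dim)     # stand-in for FeedForward
        self.skip_norm = nn.LayerNorm(2 * dim) if skip else None
        self.skip_linear = nn.Linear(2 * dim, dim) if skip else None

    def forward(self, hidden_states, timestep_emb, skip=None):
        if self.skip_linear is not None:
            cat = torch.cat([hidden_states, skip], dim=-1)
            hidden_states = self.skip_linear(self.skip_norm(cat))
        shift_msa = self.default_modulation(timestep_emb).unsqueeze(1)
        hidden_states = hidden_states + self.attn1(self.norm1(hidden_states) + shift_msa)
        hidden_states = hidden_states + self.attn2(self.norm3(hidden_states))
        hidden_states = hidden_states + self.ff(self.norm2(hidden_states))
        return hidden_states

x, t = torch.randn(2, 64, 128), torch.randn(2, 128)
assert TinyHunyuanStyleBlock(128)(x, t).shape == x.shape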