@@ -84,214 +84,6 @@ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
return x
- ### TODO: XCLiu: some ugly helper functions, please clean later
- ### ==== begin ====
- def modulate(x, shift, scale):
-     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
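- ### e.g. modulate(x, shift, scale) with x: (B, L, D) and shift, scale: (B, D) -> (B, L, D)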
-
- class FP32_Layernorm(nn.LayerNorm):
-     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
-         origin_dtype = inputs.dtype
-         return F.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(),
-                             self.eps).to(origin_dtype)
-
-
- class FP32_SiLU(nn.SiLU):
-     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
-         return torch.nn.functional.silu(inputs.float(), inplace=False).to(inputs.dtype)
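- ### Both wrappers above upcast their inputs to float32, apply the op, and cast the result
- ### back to the input dtype, keeping LayerNorm / SiLU numerically stable under fp16/bf16.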
-
- from typing import Tuple, Union, Optional
-
- class HunyuanDiTAttentionPool(nn.Module):
-     def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
-         super().__init__()
-         self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5)
-         self.k_proj = nn.Linear(embed_dim, embed_dim)
-         self.q_proj = nn.Linear(embed_dim, embed_dim)
-         self.v_proj = nn.Linear(embed_dim, embed_dim)
-         self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
-         self.num_heads = num_heads
-
-     def forward(self, x):
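-         ### x: (N, L, C) with L == spacial_dim; a mean token is prepended, attention pools
-         ### over the L + 1 tokens, and the result is projected to (N, output_dim or embed_dim).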
-         x = x.permute(1, 0, 2)  # NLC -> LNC
-         x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC
-         x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (L+1)NC
-         x, _ = F.multi_head_attention_forward(
-             query=x[:1], key=x, value=x,
-             embed_dim_to_check=x.shape[-1],
-             num_heads=self.num_heads,
-             q_proj_weight=self.q_proj.weight,
-             k_proj_weight=self.k_proj.weight,
-             v_proj_weight=self.v_proj.weight,
-             in_proj_weight=None,
-             in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
-             bias_k=None,
-             bias_v=None,
-             add_zero_attn=False,
-             dropout_p=0,
-             out_proj_weight=self.c_proj.weight,
-             out_proj_bias=self.c_proj.bias,
-             use_separate_proj_weight=True,
-             training=self.training,
-             need_weights=False
-         )
-         return x.squeeze(0)
- ### ==== end ====
-
-
- @maybe_allow_in_graph
- class HunyuanDiTBlock(nn.Module):
-     r"""
-     HunyuanDiT Transformer block. Allows a long skip connection and QK normalization.
-
-     Parameters:
-         dim (`int`): The number of channels in the input and output.
-         num_attention_heads (`int`): The number of heads to use for multi-head attention.
-         text_dim (`int`, *optional*, defaults to 1024):
-             The size of the `encoder_hidden_states` vector used for cross attention.
-         dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
-         activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in the feed-forward.
-         norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
-             Whether to use learnable elementwise affine parameters for normalization.
-         norm_eps (`float`, *optional*, defaults to 1e-6): The epsilon used by the normalization layers.
-         final_dropout (`bool`, *optional*, defaults to `False`):
-             Whether to apply a final dropout after the last feed-forward layer.
-         ff_inner_dim (`int`, *optional*): The inner dimension of the feed-forward network.
-         ff_bias (`bool`, *optional*, defaults to `True`): Whether to use a bias in the feed-forward network.
-         skip (`bool`, *optional*, defaults to `False`):
-             Whether to accept a long skip connection input, concatenated and projected before the block.
-         qk_norm (`bool`, *optional*, defaults to `True`):
-             Whether to apply layer normalization to the query and key projections in attention.
-     """
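-     ### Usage sketch (values are illustrative assumptions, not from this module):
-     ###   block = HunyuanDiTBlock(dim=1408, num_attention_heads=16, text_dim=1024,
-     ###                           ff_inner_dim=int(1408 * 4.0), skip=False)
-     ###   out = block(hidden_states, encoder_hidden_states, timestep_emb, freq_cis_img=rotary_emb)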
-
-     def __init__(
-         self,
-         dim: int,
-         num_attention_heads: int,
-         text_dim: int = 1024,
-         dropout=0.0,
-         activation_fn: str = "geglu",
-         norm_elementwise_affine: bool = True,
-         norm_eps: float = 1e-6,
-         final_dropout: bool = False,
-         ff_inner_dim: Optional[int] = None,
-         ff_bias: bool = True,
-         skip: bool = False,
-         qk_norm: bool = True,
-     ):
-         super().__init__()
-
-         # Define 3 blocks. Each block has its own normalization layer.
-         # NOTE: when a new version comes, check norm2 and norm3
-         # 1. Self-Attn
-         self.norm1 = FP32_Layernorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
-
-         from .attention_processor import HunyuanAttnProcessor2_0
-         self.attn1 = Attention(
-             query_dim=dim,
-             cross_attention_dim=dim,
-             dim_head=dim // num_attention_heads,
-             heads=num_attention_heads,
-             qk_norm="layer_norm" if qk_norm else None,
-             eps=1e-6,
-             bias=True,
-             processor=HunyuanAttnProcessor2_0(),
-         )
-
-         # 2. Cross-Attn
-         self.norm2 = FP32_Layernorm(dim, norm_eps, norm_elementwise_affine)
-
-         self.attn2 = Attention(
-             query_dim=dim,
-             cross_attention_dim=text_dim,
-             dim_head=dim // num_attention_heads,
-             heads=num_attention_heads,
-             qk_norm="layer_norm" if qk_norm else None,
-             eps=1e-6,
-             bias=True,
-             processor=HunyuanAttnProcessor2_0(),
-         )
-         # 3. Feed-forward
-         self.norm3 = FP32_Layernorm(dim, norm_eps, norm_elementwise_affine)
-
-         self.ff = FeedForward(
-             dim,
-             dropout=dropout,  ### 0.0
-             activation_fn=activation_fn,  ### approx GeLU
-             final_dropout=final_dropout,  ### 0.0
-             inner_dim=ff_inner_dim,  ### int(dim * mlp_ratio)
-             bias=ff_bias,
-         )
-
-         # 4. Skip Connection
-         if skip:
-             self.skip_norm = FP32_Layernorm(2 * dim, norm_eps, elementwise_affine=True)
-             self.skip_linear = nn.Linear(2 * dim, dim)
-         else:
-             self.skip_linear = None
-
-         # 5. SDXL-style modulation with add
-         self.default_modulation = nn.Sequential(
-             FP32_SiLU(),
-             nn.Linear(dim, dim, bias=True)
-         )
-
-         # let chunk size default to None
-         self._chunk_size = None
-         self._chunk_dim = 0
-
-     def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
-         # Sets chunk feed-forward
-         self._chunk_size = chunk_size
-         self._chunk_dim = dim
-
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         encoder_hidden_states: Optional[torch.Tensor] = None,
-         timestep: Optional[torch.LongTensor] = None,
-         freq_cis_img=None,
-         skip=None,
-     ) -> torch.Tensor:
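-         ### Shape sketch (assumed, for illustration): hidden_states (B, L, dim),
-         ### encoder_hidden_states (B, S, text_dim), timestep a (B, dim) timestep embedding,
-         ### freq_cis_img the rotary frequencies for the image tokens, and skip a (B, L, dim)
-         ### tensor from the mirrored earlier block when the block was built with skip=True.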
-
-         # Notice that normalization is always applied before the real computation in the following blocks.
-         # 0. Long Skip Connection
-         if self.skip_linear is not None:
-             cat = torch.cat([hidden_states, skip], dim=-1)
-             cat = self.skip_norm(cat)
-             hidden_states = self.skip_linear(cat)
-
-         # 1. Self-Attention
-         norm_hidden_states = self.norm1(hidden_states)  ### checked: self.norm1 is correct
-         shift_msa = self.default_modulation(timestep).unsqueeze(dim=1)
-         attn_output = self.attn1(
-             norm_hidden_states + shift_msa,
-             temb=freq_cis_img,
-         )
-         hidden_states = hidden_states + attn_output
-
-         # 2. Cross-Attention
-         hidden_states = hidden_states + self.attn2(
-             self.norm2(hidden_states),
-             encoder_hidden_states=encoder_hidden_states,
-             temb=freq_cis_img,
-         )
-
-         # FFN Layer  ### TODO: switch norm2 and norm3 in the state dict
-         mlp_inputs = self.norm3(hidden_states)
-         hidden_states = hidden_states + self.ff(mlp_inputs)
-
-         return hidden_states
@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):