
Commit b0e0da2

yiyixuxu authored and committed
first draft
refactor rotary embedding + move dit block to dit model file
1 parent 5b00b4b commit b0e0da2

File tree

12 files changed: +748 −916 lines changed

src/diffusers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@
     "ConsistencyDecoderVAE",
     "ControlNetModel",
     "ControlNetXSAdapter",
-    "HunyuanDiT2DModel",
+    "HunyuanDiT2DModel",
     "I2VGenXLUNet",
     "Kandinsky3UNet",
     "ModelMixin",

src/diffusers/models/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -37,10 +37,10 @@
 _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
 _import_structure["embeddings"] = ["ImageProjection"]
 _import_structure["modeling_utils"] = ["ModelMixin"]
+_import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
 _import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
 _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
 _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
-_import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
 _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
 _import_structure["unets.unet_1d"] = ["UNet1DModel"]
 _import_structure["unets.unet_2d"] = ["UNet2DModel"]

@@ -75,8 +75,8 @@
 from .embeddings import ImageProjection
 from .modeling_utils import ModelMixin
 from .transformers import (
-    HunyuanDiT2DModel,
     DualTransformer2DModel,
+    HunyuanDiT2DModel,
     PriorTransformer,
     T5FilmDecoder,
     Transformer2DModel,
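
With the two registrations above in place, the class should be importable both from the package root and from its module path. A minimal, illustrative check (assuming this branch is installed; the assertion only verifies that the lazy-import table and the direct module path resolve to the same object):

from diffusers import HunyuanDiT2DModel  # resolved lazily via _import_structure
from diffusers.models.transformers.hunyuan_transformer_2d import HunyuanDiT2DModel as HunyuanDiT2DModelDirect

assert HunyuanDiT2DModel is HunyuanDiT2DModelDirect  # both paths point at the same class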

src/diffusers/models/attention.py

Lines changed: 0 additions & 208 deletions
@@ -84,214 +84,6 @@ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
 
         return x
 
-### TODO: XCLiu: some ugly helper functions, please clean later
-### ==== begin ====
-def modulate(x, shift, scale):
-    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
-
-class FP32_Layernorm(nn.LayerNorm):
-    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
-        origin_dtype = inputs.dtype
-        return F.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(),
-                            self.eps).to(origin_dtype)
-
-
-class FP32_SiLU(nn.SiLU):
-    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
-        return torch.nn.functional.silu(inputs.float(), inplace=False).to(inputs.dtype)
-
-from typing import Tuple, Union, Optional
-
-class HunyuanDiTAttentionPool(nn.Module):
-    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
-        super().__init__()
-        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
-        self.k_proj = nn.Linear(embed_dim, embed_dim)
-        self.q_proj = nn.Linear(embed_dim, embed_dim)
-        self.v_proj = nn.Linear(embed_dim, embed_dim)
-        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
-        self.num_heads = num_heads
-
-    def forward(self, x):
-        x = x.permute(1, 0, 2)  # NLC -> LNC
-        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC
-        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (L+1)NC
-        x, _ = F.multi_head_attention_forward(
-            query=x[:1], key=x, value=x,
-            embed_dim_to_check=x.shape[-1],
-            num_heads=self.num_heads,
-            q_proj_weight=self.q_proj.weight,
-            k_proj_weight=self.k_proj.weight,
-            v_proj_weight=self.v_proj.weight,
-            in_proj_weight=None,
-            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
-            bias_k=None,
-            bias_v=None,
-            add_zero_attn=False,
-            dropout_p=0,
-            out_proj_weight=self.c_proj.weight,
-            out_proj_bias=self.c_proj.bias,
-            use_separate_proj_weight=True,
-            training=self.training,
-            need_weights=False
-        )
-        return x.squeeze(0)
-### ==== end ====
-
-
-@maybe_allow_in_graph
-class HunyuanDiTBlock(nn.Module):
-    r"""
-    HunyuanDiT Transformer block. Allow skip connection and QKNorm
-    Parameters:
-        dim (`int`): The number of channels in the input and output.
-        num_attention_heads (`int`): The number of heads to use for multi-head attention.
-        attention_head_dim (`int`): The number of channels in each head.
-        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
-        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
-        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
-        num_embeds_ada_norm (:
-            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
-        attention_bias (:
-            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
-        only_cross_attention (`bool`, *optional*):
-            Whether to use only cross-attention layers. In this case two cross attention layers are used.
-        double_self_attention (`bool`, *optional*):
-            Whether to use two self-attention layers. In this case no cross attention layers are used.
-        upcast_attention (`bool`, *optional*):
-            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
-        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
-            Whether to use learnable elementwise affine parameters for normalization.
-        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
-            The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
-        final_dropout (`bool` *optional*, defaults to False):
-            Whether to apply a final dropout after the last feed-forward layer.
-        attention_type (`str`, *optional*, defaults to `"default"`):
-            The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
-        positional_embeddings (`str`, *optional*, defaults to `None`):
-            The type of positional embeddings to apply to.
-        num_positional_embeddings (`int`, *optional*, defaults to `None`):
-            The maximum number of positional embeddings to apply.
-    """
-
-    def __init__(
-        self,
-        dim: int,
-        num_attention_heads: int,
-        text_dim: int = 1024,
-        dropout=0.0,
-        activation_fn: str = "geglu",
-        norm_elementwise_affine: bool = True,
-        norm_eps: float = 1e-6,
-        final_dropout: bool = False,
-        ff_inner_dim: Optional[int] = None,
-        ff_bias: bool = True,
-        skip: bool = False,
-        qk_norm: bool = True,
-    ):
-        super().__init__()
-
-        # Define 3 blocks. Each block has its own normalization layer.
-        # NOTE: when new version comes, chech norm2 and norm 3
-        # 1. Self-Attn
-        self.norm1 = FP32_Layernorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
-
-        from .attention_processor import HunyuanAttnProcessor2_0
-        self.attn1 = Attention(
-            query_dim=dim,
-            cross_attention_dim=dim,
-            dim_head=dim // num_attention_heads,
-            heads=num_attention_heads,
-            qk_norm="layer_norm" if qk_norm else None,
-            eps=1e-6,
-            bias=True,
-            processor=HunyuanAttnProcessor2_0(),
-        )
-
-        # 2. Cross-Attn
-        self.norm2 = FP32_Layernorm(dim, norm_eps, norm_elementwise_affine)
-
-        self.attn2 = Attention(
-            query_dim=dim,
-            cross_attention_dim=text_dim,
-            dim_head=dim // num_attention_heads,
-            heads=num_attention_heads,
-            qk_norm="layer_norm" if qk_norm else None,
-            eps=1e-6,
-            bias=True,
-            processor=HunyuanAttnProcessor2_0(),
-        )
-        # 3. Feed-forward
-        self.norm3 = FP32_Layernorm(dim, norm_eps, norm_elementwise_affine)
-
-        self.ff = FeedForward(
-            dim,
-            dropout=dropout,  ### 0.0
-            activation_fn=activation_fn,  ### approx GeLU
-            final_dropout=final_dropout,  ### 0.0
-            inner_dim=ff_inner_dim,  ### int(dim * mlp_ratio)
-            bias=ff_bias,
-        )
-
-        # 4. Skip Connection
-        if skip:
-            self.skip_norm = FP32_Layernorm(2 * dim, norm_eps, elementwise_affine=True)
-            self.skip_linear = nn.Linear(2 * dim, dim)
-        else:
-            self.skip_linear = None
-
-        # 5. SDXL-style modulation with add
-        self.default_modulation = nn.Sequential(
-            FP32_SiLU(),
-            nn.Linear(dim, dim, bias=True)
-        )
-
-        # let chunk size default to None
-        self._chunk_size = None
-        self._chunk_dim = 0
-
-    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
-        # Sets chunk feed-forward
-        self._chunk_size = chunk_size
-        self._chunk_dim = dim
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        timestep: Optional[torch.LongTensor] = None,
-        freq_cis_img=None,
-        skip=None
-    ) -> torch.Tensor:
-
-        # Notice that normalization is always applied before the real computation in the following blocks.
-        # 0. Long Skip Connection
-        if self.skip_linear is not None:
-            cat = torch.cat([hidden_states, skip], dim=-1)
-            cat = self.skip_norm(cat)
-            hidden_states = self.skip_linear(cat)
-
-        # 1. Self-Attention
-        norm_hidden_states = self.norm1(hidden_states)  ### checked: self.norm1 is correct
-        shift_msa = self.default_modulation(timestep).unsqueeze(dim=1)
-        attn_output = self.attn1(
-            norm_hidden_states + shift_msa,
-            temb=freq_cis_img,
-        )
-        hidden_states = hidden_states + attn_output
-
-        # 2. Cross-Attention
-        hidden_states = hidden_states + self.attn2(
-            self.norm2(hidden_states),
-            encoder_hidden_states=encoder_hidden_states,
-            temb=freq_cis_img,
-        )
-
-        # FFN Layer  ### TODO: switch norm2 and norm3 in the state dict
-        mlp_inputs = self.norm3(hidden_states)
-        hidden_states = hidden_states + self.ff(mlp_inputs)
-
-        return hidden_states
 
 @maybe_allow_in_graph
 class BasicTransformerBlock(nn.Module):
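
For readers following the move, the removed block boils down to the data flow below. This is a condensed, illustrative sketch of the forward pass shown above, not the relocated implementation itself (per the commit message, the real block now lives next to HunyuanDiT2DModel in the transformer model file); `block` stands for an instance of the block above.

import torch


def hunyuan_dit_block_flow(block, hidden_states, encoder_hidden_states, timestep_emb, rope, skip=None):
    # 0. Long skip connection: concatenate the skip branch, normalize, and project back to `dim`.
    if block.skip_linear is not None:
        hidden_states = block.skip_linear(block.skip_norm(torch.cat([hidden_states, skip], dim=-1)))

    # 1. Self-attention: additive (SDXL-style) timestep modulation on the normed tokens, RoPE passed via `temb`.
    shift_msa = block.default_modulation(timestep_emb).unsqueeze(1)
    hidden_states = hidden_states + block.attn1(block.norm1(hidden_states) + shift_msa, temb=rope)

    # 2. Cross-attention against the text encoder states.
    hidden_states = hidden_states + block.attn2(
        block.norm2(hidden_states), encoder_hidden_states=encoder_hidden_states, temb=rope
    )

    # 3. Feed-forward.
    return hidden_states + block.ff(block.norm3(hidden_states))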

src/diffusers/models/attention_processor.py

Lines changed: 9 additions & 94 deletions
@@ -161,7 +161,7 @@ def __init__(
             self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
         else:
             self.spatial_norm = None
-
+
         if qk_norm is None:
             self.norm_q = None
             self.norm_k = None
@@ -1435,6 +1435,7 @@
 
         return hidden_states
 
+
 class HunyuanAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
@@ -1451,7 +1452,9 @@ def __call__(
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         temb: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        from .embeddings import apply_rotary_emb
 
         residual = hidden_states
         if attn.spatial_norm is not None:
@@ -1478,10 +1481,8 @@ def __call__(
 
         query = attn.to_q(hidden_states)
 
-        apply_rotary_emb_on_key = False
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-            apply_rotary_emb_on_key = True
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
@@ -1502,16 +1503,10 @@ def __call__(
             key = attn.norm_k(key)
 
         # Apply RoPE if needed
-        if temb is not None:
-            if apply_rotary_emb_on_key:
-                qq, kk = apply_rotary_emb(query, key, temb, head_first=True)
-                assert qq.shape == query.shape and kk.shape == key.shape, \
-                    f'qq: {qq.shape}, q: {query.shape}, kk: {kk.shape}, key: {key.shape}'
-                query, key = qq, kk
-            else:
-                qq, _ = apply_rotary_emb(query, None, temb, head_first=True)
-                assert qq.shape == query.shape, f'qq: {qq.shape}, query: {query.shape}'
-                query = qq
+        if image_rotary_emb is not None:
+            query = apply_rotary_emb(query, image_rotary_emb)
+            if not attn.is_cross_attention:
+                key = apply_rotary_emb(key, image_rotary_emb)
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
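
The hunk above swaps the bundled query/key helper for an `apply_rotary_emb` imported from `.embeddings` and a single `image_rotary_emb` argument. That helper is not part of this diff; a minimal sketch of what it is assumed to do, following the real-valued `(cos, sin)` branch of the implementation removed at the end of this file, could look like this (the name `apply_rotary_emb_sketch` is illustrative, not the library's function):

from typing import Tuple

import torch


def apply_rotary_emb_sketch(x: torch.Tensor, freqs_cis: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    # x: [batch, heads, seq, head_dim]; freqs_cis: (cos, sin), each [seq, head_dim].
    cos, sin = freqs_cis
    cos = cos[None, None].to(x.device)  # -> [1, 1, seq, head_dim] for broadcasting
    sin = sin[None, None].to(x.device)

    # "rotate half": treat (even, odd) channel pairs as 2D vectors and rotate each pair by its angle.
    x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(-2)

    return (x.float() * cos + x_rotated * sin).type_as(x)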
@@ -1537,6 +1532,7 @@ def __call__(
 
         return hidden_states
 
+
 class FusedAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
@@ -2808,84 +2804,3 @@ def __call__(
     LoRAXFormersAttnProcessor,
     LoRAAttnAddedKVProcessor,
 ]
-
-from typing import Tuple
-
-def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False):
-    """
-    Reshape frequency tensor for broadcasting it with another tensor.
-    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
-    for the purpose of broadcasting the frequency tensor during element-wise operations.
-    Args:
-        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
-        x (torch.Tensor): Target tensor for broadcasting compatibility.
-        head_first (bool): head dimension first (except batch dim) or not.
-    Returns:
-        torch.Tensor: Reshaped frequency tensor.
-    Raises:
-        AssertionError: If the frequency tensor doesn't match the expected shape.
-        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
-    """
-    ndim = x.ndim
-    assert 0 <= 1 < ndim
-
-    if isinstance(freqs_cis, tuple):
-        # freqs_cis: (cos, sin) in real space
-        if head_first:
-            assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
-            shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
-        else:
-            assert freqs_cis[0].shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
-            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
-        return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
-    else:
-        # freqs_cis: values in complex space
-        if head_first:
-            assert freqs_cis.shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
-            shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
-        else:
-            assert freqs_cis.shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
-            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
-        return freqs_cis.view(*shape)
-
-
-def rotate_half(x):
-    x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
-    return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
-
-def apply_rotary_emb(
-    xq: torch.Tensor,
-    xk: Optional[torch.Tensor],
-    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
-    head_first: bool = False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """
-    Apply rotary embeddings to input tensors using the given frequency tensor.
-    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
-    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
-    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
-    returned as real tensors.
-    Args:
-        xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
-        xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
-        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Precomputed frequency tensor for complex exponentials.
-        head_first (bool): head dimension first (except batch dim) or not.
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
-    """
-    xk_out = None
-    if isinstance(freqs_cis, tuple):
-        cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)  # [S, D]
-        cos, sin = cos.to(xq.device), sin.to(xq.device)
-        xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
-        if xk is not None:
-            xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
-    else:
-        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # [B, S, H, D//2]
-        freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device)  # [S, D//2] --> [1, S, 1, D//2]
-        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
-        if xk is not None:
-            xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  # [B, S, H, D//2]
-            xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
-
-    return xq_out, xk_out
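
For completeness, a toy end-to-end example of the `(cos, sin)` convention implied by the new code path. All sizes and the frequency construction here are illustrative, not the model's actual rotary schedule; `rotate` repeats the rotate-half convention from the sketch after the RoPE hunk above so this snippet runs on its own.

import torch


def rotate(x, cos, sin):
    # same rotate-half convention as the apply_rotary_emb sketch above
    x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
    x_rot = torch.stack([-x_imag, x_real], dim=-1).flatten(-2)
    return (x.float() * cos[None, None] + x_rot * sin[None, None]).type_as(x)


batch, heads, head_dim = 2, 8, 64
height = width = 4
seq = height * width  # image tokens on a flattened 2D grid

# Illustrative 2D frequencies: half the head dim rotates with the row index, half with the column index.
dim_quarter = head_dim // 4
inv_freq = 1.0 / (10000 ** (torch.arange(dim_quarter).float() / dim_quarter))
rows = torch.arange(height).repeat_interleave(width).float()
cols = torch.arange(width).repeat(height).float()
angles = torch.cat([torch.outer(rows, inv_freq), torch.outer(cols, inv_freq)], dim=-1)  # [seq, head_dim // 2]
angles = angles.repeat_interleave(2, dim=-1)  # each (even, odd) channel pair shares one angle -> [seq, head_dim]

image_rotary_emb = (angles.cos(), angles.sin())

query = torch.randn(batch, heads, seq, head_dim)
rotated = rotate(query, *image_rotary_emb)
assert rotated.shape == query.shape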
