
Commit 2e7bc12

root authored and committed
add yiyi attention refactor; removed timm
1 parent 154e150 commit 2e7bc12

19 files changed: +355 -389 lines

examples/community/README.md
File mode changed: 100755 → 100644

examples/community/dps_pipeline.py
File mode changed: 100755 → 100644

examples/community/latent_consistency_txt2img.py
File mode changed: 100755 → 100644

examples/community/one_step_unet.py
File mode changed: 100755 → 100644

examples/community/sd_text2img_k_diffusion.py
File mode changed: 100755 → 100644

examples/community/stable_diffusion_tensorrt_img2img.py
File mode changed: 100755 → 100644

examples/community/stable_diffusion_tensorrt_inpaint.py
File mode changed: 100755 → 100644

examples/community/stable_diffusion_tensorrt_txt2img.py
File mode changed: 100755 → 100644

scripts/convert_dance_diffusion_to_diffusers.py
File mode changed: 100755 → 100644

src/diffusers/models/attention.py

Lines changed: 34 additions & 239 deletions
@@ -102,93 +102,6 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:

 from typing import Tuple, Union, Optional

-def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False):
-    """
-    Reshape frequency tensor for broadcasting it with another tensor.
-
-    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
-    for the purpose of broadcasting the frequency tensor during element-wise operations.
-
-    Args:
-        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
-        x (torch.Tensor): Target tensor for broadcasting compatibility.
-        head_first (bool): head dimension first (except batch dim) or not.
-
-    Returns:
-        torch.Tensor: Reshaped frequency tensor.
-
-    Raises:
-        AssertionError: If the frequency tensor doesn't match the expected shape.
-        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
-    """
-    ndim = x.ndim
-    assert 0 <= 1 < ndim
-
-    if isinstance(freqs_cis, tuple):
-        # freqs_cis: (cos, sin) in real space
-        if head_first:
-            assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
-            shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
-        else:
-            assert freqs_cis[0].shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
-            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
-        return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
-    else:
-        # freqs_cis: values in complex space
-        if head_first:
-            assert freqs_cis.shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
-            shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
-        else:
-            assert freqs_cis.shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
-            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
-        return freqs_cis.view(*shape)
-
-
-def rotate_half(x):
-    x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
-    return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
-
-def apply_rotary_emb(
-    xq: torch.Tensor,
-    xk: Optional[torch.Tensor],
-    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
-    head_first: bool = False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """
-    Apply rotary embeddings to input tensors using the given frequency tensor.
-
-    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
-    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
-    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
-    returned as real tensors.
-
-    Args:
-        xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
-        xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
-        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Precomputed frequency tensor for complex exponentials.
-        head_first (bool): head dimension first (except batch dim) or not.
-
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
-
-    """
-    xk_out = None
-    if isinstance(freqs_cis, tuple):
-        cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)  # [S, D]
-        cos, sin = cos.to(xq.device), sin.to(xq.device)
-        xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
-        if xk is not None:
-            xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
-    else:
-        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # [B, S, H, D//2]
-        freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device)  # [S, D//2] --> [1, S, 1, D//2]
-        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
-        if xk is not None:
-            xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  # [B, S, H, D//2]
-            xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
-
-    return xq_out, xk_out
-
 class HunyuanDiTAttentionPool(nn.Module):
     def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
         super().__init__()
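
The helpers deleted above implement the standard real-valued RoPE update: every (even, odd) channel pair of the query and key is rotated by a position-dependent angle, i.e. q' = q * cos + rotate_half(q) * sin. A minimal self-contained sketch of that update (simplified shapes; not the deleted functions verbatim, and without the broadcast helper or the complex-valued branch):

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dim into (real, imag) pairs and rotate each pair by 90 degrees:
    # (a, b) -> (-b, a), mirroring the deleted rotate_half.
    x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
    return torch.stack([-x_imag, x_real], dim=-1).flatten(-2)

def apply_rope(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    # q, k: [batch, seq, heads, head_dim]; cos, sin: [seq, head_dim], broadcast over batch and heads.
    cos = cos[None, :, None, :]
    sin = sin[None, :, None, :]
    q_out = (q.float() * cos + rotate_half(q) * sin).type_as(q)
    k_out = (k.float() * cos + rotate_half(k) * sin).type_as(k)
    return q_out, k_out

The deleted complex-valued branch (torch.view_as_complex / torch.view_as_real) expresses the same pairwise rotation as a complex multiplication.
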
@@ -223,147 +136,13 @@ def forward(self, x):
             need_weights=False
         )
         return x.squeeze(0)
-
-class HunyuanDiTCrossAttention(nn.Module):
-    """
-    Use QK Normalization.
-    """
-    def __init__(self,
-                 qdim,
-                 kdim,
-                 num_heads,
-                 qkv_bias=True,
-                 qk_norm=False,
-                 attn_drop=0.0,
-                 proj_drop=0.0,
-                 device=None,
-                 dtype=None,
-                 norm_layer=nn.LayerNorm,
-                 ):
-        factory_kwargs = {'device': device, 'dtype': dtype}
-        super().__init__()
-        self.qdim = qdim
-        self.kdim = kdim
-        self.num_heads = num_heads
-        assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
-        self.head_dim = self.qdim // num_heads
-        assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
-        self.scale = self.head_dim ** -0.5
-
-        self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
-        self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)
-
-        # TODO: eps should be 1 / 65530 if using fp16
-        self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
-        self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
-        self.attn_drop = nn.Dropout(attn_drop)
-        self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
-        self.proj_drop = nn.Dropout(proj_drop)
-
-    def forward(self, x, y, freqs_cis_img=None):
-        """
-        Parameters
-        ----------
-        x: torch.Tensor
-            (batch, seqlen1, hidden_dim) (where hidden_dim = num heads * head dim)
-        y: torch.Tensor
-            (batch, seqlen2, hidden_dim2)
-        freqs_cis_img: torch.Tensor
-            (batch, hidden_dim // 2), RoPE for image
-        """
-        b, s1, c = x.shape     # [b, s1, D]
-        _, s2, c = y.shape     # [b, s2, 1024]
-
-        q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim)       # [b, s1, h, d]
-        kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim)  # [b, s2, 2, h, d]
-        k, v = kv.unbind(dim=2)                                             # [b, s, h, d]
-        q = self.q_norm(q)
-        k = self.k_norm(k)
-
-        # Apply RoPE if needed
-        if freqs_cis_img is not None:
-            qq, _ = apply_rotary_emb(q, None, freqs_cis_img)
-            assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}'
-            q = qq
-
-        q = q * self.scale
-        q = q.transpose(-2, -3).contiguous()    # q -> B, L1, H, C - B, H, L1, C
-        k = k.permute(0, 2, 3, 1).contiguous()  # k -> B, L2, H, C - B, H, C, L2
-        attn = q @ k                            # attn -> B, H, L1, L2
-        attn = attn.softmax(dim=-1)             # attn -> B, H, L1, L2
-        attn = self.attn_drop(attn)
-        x = attn @ v.transpose(-2, -3)          # v -> B, L2, H, C - B, H, L2, C    x -> B, H, L1, C
-        context = x.transpose(1, 2)             # context -> B, H, L1, C - B, L1, H, C
-
-        context = context.contiguous().view(b, s1, -1)
-
-        out = self.out_proj(context)            # context.reshape - B, L1, -1
-        out = self.proj_drop(out)
-
-        out_tuple = (out,)
-
-        return out_tuple
-
-
-class HunyuanDiTAttention(nn.Module):
-    """
-    We rename some layer names to align with flash attention
-    """
-    def __init__(self, dim, num_heads, qkv_bias=True, qk_norm=False, attn_drop=0., proj_drop=0.,
-                 norm_layer=nn.LayerNorm,
-                 ):
-        super().__init__()
-        self.dim = dim
-        self.num_heads = num_heads
-        assert self.dim % num_heads == 0, 'dim should be divisible by num_heads'
-        self.head_dim = self.dim // num_heads
-        # This assertion is aligned with flash attention
-        assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
-        self.scale = self.head_dim ** -0.5
-
-        # qkv --> Wqkv
-        self.Wqkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-        # TODO: eps should be 1 / 65530 if using fp16
-        self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
-        self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
-        self.attn_drop = nn.Dropout(attn_drop)
-        self.out_proj = nn.Linear(dim, dim)
-        self.proj_drop = nn.Dropout(proj_drop)
-
-    def forward(self, x, freqs_cis_img=None):
-        B, N, C = x.shape
-        qkv = self.Wqkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)  # [3, b, h, s, d]
-        q, k, v = qkv.unbind(0)  # [b, h, s, d]
-        q = self.q_norm(q)       # [b, h, s, d]
-        k = self.k_norm(k)       # [b, h, s, d]
-
-        # Apply RoPE if needed
-        if freqs_cis_img is not None:
-            qq, kk = apply_rotary_emb(q, k, freqs_cis_img, head_first=True)
-            assert qq.shape == q.shape and kk.shape == k.shape, \
-                f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}'
-            q, k = qq, kk
-
-        q = q * self.scale
-        attn = q @ k.transpose(-2, -1)  # [b, h, s, d] @ [b, h, d, s]
-        attn = attn.softmax(dim=-1)     # [b, h, s, s]
-        attn = self.attn_drop(attn)
-        x = attn @ v                    # [b, h, s, d]
-
-        x = x.transpose(1, 2).reshape(B, N, C)  # [b, s, h, d]
-        x = self.out_proj(x)
-        x = self.proj_drop(x)
-
-        out_tuple = (x,)
-
-        return out_tuple
 ### ==== end ====

+
 @maybe_allow_in_graph
 class HunyuanDiTBlock(nn.Module):
     r"""
     HunyuanDiT Transformer block. Allow skip connection and QKNorm
-
     Parameters:
         dim (`int`): The number of channels in the input and output.
         num_attention_heads (`int`): The number of heads to use for multi-head attention.
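
Both deleted modules spell out softmax attention by hand: scale q, form q @ k^T, softmax, dropout, then multiply by v. Under PyTorch 2.0 that inner computation collapses into one fused call, which is presumably what the "2_0" suffix of HunyuanAttnProcessor2_0 hints at (the processor itself lives in attention_processor.py and is not part of this file's diff). A minimal comparison sketch, not code from the commit:

import torch
import torch.nn.functional as F

def manual_attention(q, k, v, dropout_p: float = 0.0):
    # q, k, v: [batch, heads, seq, head_dim], the layout the deleted HunyuanDiTAttention used.
    scale = q.shape[-1] ** -0.5
    attn = (q * scale) @ k.transpose(-2, -1)   # [b, h, s_q, s_k]
    attn = attn.softmax(dim=-1)
    attn = F.dropout(attn, p=dropout_p)
    return attn @ v                            # [b, h, s_q, head_dim]

def fused_attention(q, k, v, dropout_p: float = 0.0):
    # Same math through the fused PyTorch 2.0 kernel (default scale is 1/sqrt(head_dim)).
    return F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

With dropout_p=0.0 the two agree up to floating-point error.
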
@@ -416,19 +195,36 @@ def __init__(
         # 1. Self-Attn
         self.norm1 = FP32_Layernorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

-        self.attn1 = HunyuanDiTAttention(dim, num_heads=num_attention_heads, qkv_bias=True, qk_norm=qk_norm)
+        from .attention_processor import HunyuanAttnProcessor2_0
+        self.attn1 = Attention(
+            query_dim=dim,
+            cross_attention_dim=dim,
+            dim_head=dim // num_attention_heads,
+            heads=num_attention_heads,
+            qk_norm="layer_norm" if qk_norm else None,
+            eps=1e-6,
+            bias=True,
+            processor=HunyuanAttnProcessor2_0(),
+        )

         # 2. Cross-Attn
         self.norm3 = FP32_Layernorm(dim, norm_eps, norm_elementwise_affine)

-        self.attn2 = HunyuanDiTCrossAttention(dim, text_dim, num_heads=num_attention_heads, qkv_bias=True, qk_norm=qk_norm)
-
+        self.attn2 = Attention(
+            query_dim=dim,
+            cross_attention_dim=text_dim,
+            dim_head=dim // num_attention_heads,
+            heads=num_attention_heads,
+            qk_norm="layer_norm" if qk_norm else None,
+            eps=1e-6,
+            bias=True,
+            processor=HunyuanAttnProcessor2_0(),
+        )
         # 3. Feed-forward
         self.norm2 = FP32_Layernorm(dim, norm_eps, norm_elementwise_affine)

-        ### NOTE: do not switch norm2 and norm3, otherwise will load wrong key when using pretrained model!
+        ### TODO: switch norm2 and norm3 in the state dict

-        #print('mlp hidden dim:', ff_inner_dim)
         self.ff = FeedForward(
             dim,
             dropout=dropout,  ### 0.0
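
The two Attention layers above delegate the actual computation to a processor object. The shipped HunyuanAttnProcessor2_0 is defined in attention_processor.py and is not shown in this diff; the skeleton below only illustrates the processor protocol the refactor relies on, with the projection and norm attribute names (to_q, to_k, to_v, norm_q, norm_k, to_out) assumed from diffusers' standard Attention layout and the RoPE step left as a placeholder. Note how the block passes the image RoPE table through the temb keyword.

import torch.nn.functional as F

class RotaryAttnProcessorSketch:
    """Illustrative skeleton of a diffusers attention processor; not the shipped HunyuanAttnProcessor2_0."""

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        # Self-attention when no encoder states are provided, cross-attention otherwise.
        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(context)
        value = attn.to_v(context)

        batch, seq_len, _ = hidden_states.shape
        head_dim = query.shape[-1] // attn.heads
        # [b, s, h*d] -> [b, h, s, d]
        query = query.view(batch, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch, -1, attn.heads, head_dim).transpose(1, 2)

        # QK-norm, present when the layer was built with qk_norm="layer_norm".
        if getattr(attn, "norm_q", None) is not None:
            query = attn.norm_q(query)
        if getattr(attn, "norm_k", None) is not None:
            key = attn.norm_k(key)

        if temb is not None:
            # Placeholder: in this block `temb` carries the image RoPE frequencies;
            # rotate query (and key, for self-attention) here, as in the rotate_half
            # sketch after the first hunk.
            pass

        out = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
        out = out.transpose(1, 2).reshape(batch, seq_len, -1)

        # to_out is [Linear, Dropout] on diffusers' Attention.
        out = attn.to_out[0](out)
        out = attn.to_out[1](out)
        return out

Because the processor returns a plain tensor, the old (out,) tuple indexing disappears from forward() in the next hunk.
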
@@ -475,28 +271,27 @@ def forward(
             cat = torch.cat([hidden_states, skip], dim=-1)
             cat = self.skip_norm(cat)
             hidden_states = self.skip_linear(cat)
-
-        #print('x:', hidden_states[0])
+
         # 1. Self-Attention
         norm_hidden_states = self.norm1(hidden_states)  ### checked: self.norm1 is correct
         shift_msa = self.default_modulation(timestep).unsqueeze(dim=1)
-        attn_inputs = (norm_hidden_states + shift_msa, freq_cis_img,)
-        attn_output = self.attn1(*attn_inputs)[0]
+        attn_output = self.attn1(
+            norm_hidden_states + shift_msa,
+            temb=freq_cis_img,
+        )
         hidden_states = hidden_states + attn_output
-        #print('x:', hidden_states[0])

         # 2. Cross-Attention
-        cross_inputs = (
-            self.norm3(hidden_states), encoder_hidden_states, freq_cis_img
+        hidden_states = hidden_states + self.attn2(
+            self.norm3(hidden_states),
+            encoder_hidden_states=encoder_hidden_states,
+            temb=freq_cis_img,
         )
-        hidden_states = hidden_states + self.attn2(*cross_inputs)[0]
-        #print('x:', hidden_states[0])

-        # FFN Layer ### NOTE: do not switch norm2 and norm3, otherwise will load wrong key when using pretrained model!
+        # FFN Layer ### TODO: switch norm2 and norm3 in the state dict
         mlp_inputs = self.norm2(hidden_states)
         hidden_states = hidden_states + self.ff(mlp_inputs)
-        #print('x:', hidden_states[0])
-
+
         return hidden_states

 @maybe_allow_in_graph
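
With both layers expressed as regular diffusers Attention modules, processors become swappable after construction. A hypothetical usage sketch, reusing the illustrative RotaryAttnProcessorSketch from above and assuming Attention.set_processor and qk_norm support are present in the checked-out diffusers tree (sizes are made up, not taken from the HunyuanDiT config):

import torch
from diffusers.models.attention_processor import Attention, HunyuanAttnProcessor2_0

dim, heads = 1024, 8  # hypothetical example sizes

attn = Attention(
    query_dim=dim,
    cross_attention_dim=dim,
    dim_head=dim // heads,
    heads=heads,
    qk_norm="layer_norm",
    eps=1e-6,
    bias=True,
    processor=HunyuanAttnProcessor2_0(),
)

# Swap in the illustrative skeleton without touching the module itself.
attn.set_processor(RotaryAttnProcessorSketch())
out = attn(torch.randn(2, 64, dim))  # [2, 64, dim]; temb defaults to None in the sketch
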

0 commit comments
