Commit a83f85d

gnobitab, XCLiu, and yiyixuxu authored and committed
Tencent Hunyuan Team: add HunyuanDiT related updates (#8240)
* Hunyuan Team: add HunyuanDiT related updates

Co-authored-by: XCLiu <[email protected]>
Co-authored-by: yiyixuxu <[email protected]>
1 parent 902d799 commit a83f85d

File tree

15 files changed: +1999 -9 lines


src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -83,6 +83,7 @@
         "ControlNetModel",
         "ControlNetXSAdapter",
         "DiTTransformer2DModel",
+        "HunyuanDiT2DModel",
         "I2VGenXLUNet",
         "Kandinsky3UNet",
         "ModelMixin",
@@ -229,6 +230,7 @@
         "BlipDiffusionPipeline",
         "CLIPImageProjection",
         "CycleDiffusionPipeline",
+        "HunyuanDiTPipeline",
         "I2VGenXLPipeline",
         "IFImg2ImgPipeline",
         "IFImg2ImgSuperResolutionPipeline",
@@ -487,6 +489,7 @@
         ControlNetModel,
         ControlNetXSAdapter,
         DiTTransformer2DModel,
+        HunyuanDiT2DModel,
         I2VGenXLUNet,
         Kandinsky3UNet,
         ModelMixin,
@@ -611,6 +614,7 @@
         AudioLDMPipeline,
         CLIPImageProjection,
         CycleDiffusionPipeline,
+        HunyuanDiTPipeline,
         I2VGenXLPipeline,
         IFImg2ImgPipeline,
         IFImg2ImgSuperResolutionPipeline,
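
With these top-level exports in place, the new pipeline can be imported directly from diffusers. A minimal usage sketch; the checkpoint id, prompt, and output filename below are illustrative assumptions, not part of this commit:

import torch
from diffusers import HunyuanDiTPipeline

# Checkpoint id is assumed for illustration; any HunyuanDiT checkpoint
# published in the diffusers format should work the same way.
pipe = HunyuanDiTPipeline.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16
)
pipe.to("cuda")

image = pipe(prompt="a watercolor painting of a mountain lake").images[0]
image.save("hunyuan_dit_sample.png")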

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -37,6 +37,7 @@
     _import_structure["embeddings"] = ["ImageProjection"]
     _import_structure["modeling_utils"] = ["ModelMixin"]
     _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
+    _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
     _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
     _import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
     _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
@@ -77,6 +78,7 @@
         from .transformers import (
             DiTTransformer2DModel,
             DualTransformer2DModel,
+            HunyuanDiT2DModel,
             PixArtTransformer2DModel,
             PriorTransformer,
             T5FilmDecoder,
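
The registration above wires HunyuanDiT2DModel into the package's lazy-import machinery. A quick sanity-check sketch, assuming the module lands at the path declared in `_import_structure`: the lazy top-level attribute and the direct module path should resolve to the same class object.

from diffusers.models import HunyuanDiT2DModel
from diffusers.models.transformers.hunyuan_transformer_2d import (
    HunyuanDiT2DModel as HunyuanDiT2DModelDirect,
)

# Lazy access and direct import return the identical class.
assert HunyuanDiT2DModel is HunyuanDiT2DModelDirect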

src/diffusers/models/activations.py

Lines changed: 12 additions & 0 deletions
@@ -50,6 +50,18 @@ def get_activation(act_fn: str) -> nn.Module:
         raise ValueError(f"Unsupported activation function: {act_fn}")


+class FP32SiLU(nn.Module):
+    r"""
+    SiLU activation function with input upcasted to torch.float32.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        return F.silu(inputs.float(), inplace=False).to(inputs.dtype)
+
+
 class GELU(nn.Module):
     r"""
     GELU activation function with tanh approximation support with `approximate="tanh"`.
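
FP32SiLU only changes the precision in which the activation is evaluated. A small sketch of the effect; bfloat16 is used here just so the comparison runs on CPU, and the numbers are illustrative:

import torch
import torch.nn.functional as F

x = torch.randn(4, 8, dtype=torch.bfloat16)

silu_half = F.silu(x)                                      # evaluated in bfloat16
silu_fp32 = F.silu(x.float(), inplace=False).to(x.dtype)   # what FP32SiLU computes

# The upcasted version avoids the extra rounding error of the low-precision path.
print((silu_half - silu_fp32).abs().max())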

src/diffusers/models/attention_processor.py

Lines changed: 108 additions & 0 deletions
@@ -103,6 +103,7 @@ def __init__(
         upcast_softmax: bool = False,
         cross_attention_norm: Optional[str] = None,
         cross_attention_norm_num_groups: int = 32,
+        qk_norm: Optional[str] = None,
         added_kv_proj_dim: Optional[int] = None,
         norm_num_groups: Optional[int] = None,
         spatial_norm_dim: Optional[int] = None,
@@ -161,6 +162,15 @@ def __init__(
         else:
             self.spatial_norm = None

+        if qk_norm is None:
+            self.norm_q = None
+            self.norm_k = None
+        elif qk_norm == "layer_norm":
+            self.norm_q = nn.LayerNorm(dim_head, eps=eps)
+            self.norm_k = nn.LayerNorm(dim_head, eps=eps)
+        else:
+            raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
+
         if cross_attention_norm is None:
             self.norm_cross = None
         elif cross_attention_norm == "layer_norm":
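
The new qk_norm argument gives the attention module optional LayerNorms over dim_head (norm_q / norm_k) that the processor applies per head to the query and key projections. A short sketch with illustrative sizes, not taken from a HunyuanDiT config:

from diffusers.models.attention_processor import Attention

attn = Attention(query_dim=1408, heads=16, dim_head=88, qk_norm="layer_norm", eps=1e-6)
print(attn.norm_q)  # LayerNorm((88,), eps=1e-06, elementwise_affine=True)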
@@ -1426,6 +1436,104 @@ def __call__(
         return hidden_states


+class HunyuanAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+    used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on the query and key vectors.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        from .embeddings import apply_rotary_emb
+
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None:
+            query = apply_rotary_emb(query, image_rotary_emb)
+            if not attn.is_cross_attention:
+                key = apply_rotary_emb(key, image_rotary_emb)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class FusedAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
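
The new processor mirrors the stock AttnProcessor2_0 but additionally applies the qk LayerNorms and, when image_rotary_emb is provided, rotary embeddings to the query (and to the key for self-attention). A minimal wiring sketch; PyTorch 2.0+ is required, and the dimensions are illustrative rather than taken from a HunyuanDiT config:

import torch
from diffusers.models.attention_processor import Attention, HunyuanAttnProcessor2_0

attn = Attention(query_dim=1408, heads=16, dim_head=88, qk_norm="layer_norm", eps=1e-6)
attn.set_processor(HunyuanAttnProcessor2_0())

hidden_states = torch.randn(2, 64, 1408)  # (batch, sequence_length, query_dim)
out = attn(hidden_states)                 # self-attention; image_rotary_emb defaults to None
print(out.shape)                          # torch.Size([2, 64, 1408])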
