
[do not merge] refactor Attention class for HunYuan DIT #8265

Closed · wants to merge 2 commits
4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
@@ -82,6 +82,7 @@
"ConsistencyDecoderVAE",
"ControlNetModel",
"ControlNetXSAdapter",
"HunyuanDiT2DModel",
"I2VGenXLUNet",
"Kandinsky3UNet",
"ModelMixin",
@@ -227,6 +228,7 @@
"BlipDiffusionPipeline",
"CLIPImageProjection",
"CycleDiffusionPipeline",
"HunyuanDiTPipeline",
"I2VGenXLPipeline",
"IFImg2ImgPipeline",
"IFImg2ImgSuperResolutionPipeline",
@@ -482,6 +484,7 @@
ConsistencyDecoderVAE,
ControlNetModel,
ControlNetXSAdapter,
HunyuanDiT2DModel,
I2VGenXLUNet,
Kandinsky3UNet,
ModelMixin,
@@ -605,6 +608,7 @@
AudioLDMPipeline,
CLIPImageProjection,
CycleDiffusionPipeline,
HunyuanDiTPipeline,
I2VGenXLPipeline,
IFImg2ImgPipeline,
IFImg2ImgSuperResolutionPipeline,
2 changes: 2 additions & 0 deletions src/diffusers/models/__init__.py
@@ -39,6 +39,7 @@
_import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
@@ -73,6 +74,7 @@
from .embeddings import ImageProjection
from .modeling_utils import ModelMixin
from .transformers import (
DualTransformer2DModel,
HunyuanDiT2DModel,
PriorTransformer,
T5FilmDecoder,
219 changes: 219 additions & 0 deletions src/diffusers/models/attention.py
@@ -84,6 +84,225 @@ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:

return x

### TODO: XCLiu: some ugly helper functions, please clean later
### ==== begin ====
def modulate(x, shift, scale):
    # AdaLN-style modulation: per-sample `shift`/`scale` of shape (batch, channels),
    # broadcast over the token dimension.
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


class FP32_Layernorm(nn.LayerNorm):
    # LayerNorm computed in float32 and cast back to the input dtype, which keeps the
    # normalization numerically stable under fp16/bf16 inference.
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        origin_dtype = inputs.dtype
        return F.layer_norm(
            inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps
        ).to(origin_dtype)


class FP32_SiLU(nn.SiLU):
    # SiLU evaluated in float32, cast back to the input dtype.
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return F.silu(inputs.float(), inplace=False).to(inputs.dtype)


class HunyuanDiTAttentionPool(nn.Module):
    # CLIP-style attention pooling: a learned mean-token query attends over the sequence
    # to produce a single pooled vector (used for the pooled text embedding).
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: Optional[int] = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.permute(1, 0, 2)  # NLC -> LNC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (L+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x[:1],
            key=x,
            value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False,
        )
        return x.squeeze(0)


### ==== end ====
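
# Illustrative usage sketch (not part of the upstream diff): pooling a sequence of text features
# into a single vector with the attention pool above. The shapes (batch=2, seq_len=77, width=1024)
# are assumptions chosen only for the example.
def _example_attention_pool_usage() -> torch.Tensor:
    pool = HunyuanDiTAttentionPool(spacial_dim=77, embed_dim=1024, num_heads=8, output_dim=1024)
    text_states = torch.randn(2, 77, 1024)  # (batch, seq_len, channels), the NLC layout forward() expects
    return pool(text_states)  # -> (batch, output_dim) == (2, 1024)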

@maybe_allow_in_graph
class HunyuanDiTBlock(nn.Module):
    r"""
    Transformer block used in the HunyuanDiT model. Compared with `BasicTransformerBlock`, it supports an optional
    long skip connection and QK normalization inside its attention layers.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        text_dim (`int`, *optional*, defaults to 1024):
            The number of channels of the text (encoder) hidden states used for cross-attention.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_eps (`float`, *optional*, defaults to 1e-6): The epsilon used by the layer normalization layers.
        final_dropout (`bool`, *optional*, defaults to `False`):
            Whether to apply a final dropout after the last feed-forward layer.
        ff_inner_dim (`int`, *optional*):
            Hidden dimension of the feed-forward block, typically `int(dim * mlp_ratio)`.
        ff_bias (`bool`, *optional*, defaults to `True`): Whether to use a bias in the feed-forward block.
        skip (`bool`, *optional*, defaults to `False`):
            Whether this block receives a long skip connection from an earlier block.
        qk_norm (`bool`, *optional*, defaults to `True`):
            Whether to apply layer normalization to the query and key projections in the attention layers.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        text_dim: int = 1024,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-6,
        final_dropout: bool = False,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        skip: bool = False,
        qk_norm: bool = True,
    ):
        super().__init__()

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        self.norm1 = FP32_Layernorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        from .attention_processor import HunyuanAttnProcessor2_0

        self.attn1 = Attention(
            query_dim=dim,
            cross_attention_dim=dim,
            dim_head=dim // num_attention_heads,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=True,
            processor=HunyuanAttnProcessor2_0(),
        )

        # 2. Cross-Attn
        self.norm3 = FP32_Layernorm(dim, norm_eps, norm_elementwise_affine)

        self.attn2 = Attention(
            query_dim=dim,
            cross_attention_dim=text_dim,
            dim_head=dim // num_attention_heads,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=True,
            processor=HunyuanAttnProcessor2_0(),
        )

        # 3. Feed-forward
        self.norm2 = FP32_Layernorm(dim, norm_eps, norm_elementwise_affine)

        ### NOTE: do not switch norm2 and norm3, otherwise will load wrong key when using pretrained model!
        self.ff = FeedForward(
            dim,
            dropout=dropout,  ### 0.0
            activation_fn=activation_fn,  ### approx GeLU
            final_dropout=final_dropout,  ### 0.0
            inner_dim=ff_inner_dim,  ### int(dim * mlp_ratio)
            bias=ff_bias,
        )

        # 4. Skip Connection
        if skip:
            self.skip_norm = FP32_Layernorm(2 * dim, norm_eps, elementwise_affine=True)
            self.skip_linear = nn.Linear(2 * dim, dim)
        else:
            self.skip_linear = None

        # 5. SDXL-style modulation with add
        self.default_modulation = nn.Sequential(
            FP32_SiLU(),
            nn.Linear(dim, dim, bias=True),
        )

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.Tensor] = None,
        freq_cis_img=None,
        skip: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Notice that normalization is always applied before the real computation in the following blocks.
        # `timestep` is expected to be the already-projected timestep embedding of shape (batch_size, dim),
        # since it is fed through `self.default_modulation`; `freq_cis_img` carries the rotary position
        # embedding for the image tokens and is forwarded to the attention processor via the `temb` keyword.

        # 0. Long Skip Connection
        if self.skip_linear is not None:
            cat = torch.cat([hidden_states, skip], dim=-1)
            cat = self.skip_norm(cat)
            hidden_states = self.skip_linear(cat)

        # 1. Self-Attention
        norm_hidden_states = self.norm1(hidden_states)  ### checked: self.norm1 is correct
        shift_msa = self.default_modulation(timestep).unsqueeze(dim=1)
        attn_output = self.attn1(
            norm_hidden_states + shift_msa,
            temb=freq_cis_img,
        )
        hidden_states = hidden_states + attn_output

        # 2. Cross-Attention
        hidden_states = hidden_states + self.attn2(
            self.norm3(hidden_states),
            encoder_hidden_states=encoder_hidden_states,
            temb=freq_cis_img,
        )

        # 3. FFN Layer  ### NOTE: do not switch norm2 and norm3, otherwise will load wrong key when using pretrained model!
        mlp_inputs = self.norm2(hidden_states)
        hidden_states = hidden_states + self.ff(mlp_inputs)

        return hidden_states
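
# Illustrative usage sketch (not part of the upstream diff): one forward pass through a single
# HunyuanDiTBlock. The sizes below (dim=1152, 16 heads, text_dim=1024, mlp ratio 4, a 64x64 grid of
# image tokens) are assumptions chosen only for the example, and it assumes HunyuanAttnProcessor2_0
# skips the rotary embedding when `freq_cis_img` is None.
def _example_hunyuan_dit_block_usage() -> torch.Tensor:
    block = HunyuanDiTBlock(dim=1152, num_attention_heads=16, text_dim=1024, ff_inner_dim=int(1152 * 4))
    hidden_states = torch.randn(2, 64 * 64, 1152)  # (batch, image tokens, dim)
    encoder_hidden_states = torch.randn(2, 77, 1024)  # (batch, text tokens, text_dim)
    timestep_emb = torch.randn(2, 1152)  # already-projected timestep embedding, shape (batch, dim)
    return block(hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep_emb)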

@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):