diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 66c98804eadc..13bf9b5cd16b 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -82,6 +82,7 @@ "ConsistencyDecoderVAE", "ControlNetModel", "ControlNetXSAdapter", + "HunyuanDiT2DModel", "I2VGenXLUNet", "Kandinsky3UNet", "ModelMixin", @@ -227,6 +228,7 @@ "BlipDiffusionPipeline", "CLIPImageProjection", "CycleDiffusionPipeline", + "HunyuanDiTPipeline", "I2VGenXLPipeline", "IFImg2ImgPipeline", "IFImg2ImgSuperResolutionPipeline", @@ -482,6 +484,7 @@ ConsistencyDecoderVAE, ControlNetModel, ControlNetXSAdapter, + HunyuanDiT2DModel, I2VGenXLUNet, Kandinsky3UNet, ModelMixin, @@ -605,6 +608,7 @@ AudioLDMPipeline, CLIPImageProjection, CycleDiffusionPipeline, + HunyuanDiTPipeline, I2VGenXLPipeline, IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 78b0efff921d..04e69d5e3682 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -39,6 +39,7 @@ _import_structure["transformers.prior_transformer"] = ["PriorTransformer"] _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"] _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"] + _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"] _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] _import_structure["unets.unet_1d"] = ["UNet1DModel"] _import_structure["unets.unet_2d"] = ["UNet2DModel"] @@ -73,6 +74,7 @@ from .embeddings import ImageProjection from .modeling_utils import ModelMixin from .transformers import ( + HunyuanDiT2DModel, DualTransformer2DModel, PriorTransformer, T5FilmDecoder, diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 3d4fccb20779..b05737de1841 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -84,6 +84,225 @@ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor: return x +### TODO: XCLiu: some ugly helper functions, please clean later +### ==== begin ==== +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + +class FP32_Layernorm(nn.LayerNorm): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + origin_dtype = inputs.dtype + return F.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), + self.eps).to(origin_dtype) + + +class FP32_SiLU(nn.SiLU): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.silu(inputs.float(), inplace=False).to(inputs.dtype) + +from typing import Tuple, Union, Optional + + + +class HunyuanDiTAttentionPool(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.permute(1, 0, 2) # NLC -> LNC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, 
+            q_proj_weight=self.q_proj.weight,
+            k_proj_weight=self.k_proj.weight,
+            v_proj_weight=self.v_proj.weight,
+            in_proj_weight=None,
+            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+            bias_k=None,
+            bias_v=None,
+            add_zero_attn=False,
+            dropout_p=0,
+            out_proj_weight=self.c_proj.weight,
+            out_proj_bias=self.c_proj.bias,
+            use_separate_proj_weight=True,
+            training=self.training,
+            need_weights=False
+        )
+        return x.squeeze(0)
+
+
+### ==== end ====
+
+@maybe_allow_in_graph
+class HunyuanDiTBlock(nn.Module):
+    r"""
+    HunyuanDiT Transformer block. Allows a long skip connection and QK normalization.
+
+    Parameters:
+        dim (`int`): The number of channels in the input and output.
+        num_attention_heads (`int`): The number of heads to use for multi-head attention.
+        text_dim (`int`, *optional*, defaults to 1024):
+            The size of the `encoder_hidden_states` vector used for cross attention.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in the feed-forward layer.
+        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+            Whether to use learnable elementwise affine parameters for normalization.
+        norm_eps (`float`, *optional*, defaults to 1e-6): The epsilon used by the layer normalization layers.
+        final_dropout (`bool`, *optional*, defaults to `False`):
+            Whether to apply a final dropout after the last feed-forward layer.
+        ff_inner_dim (`int`, *optional*):
+            The inner dimension of the feed-forward layer. If `None`, the `FeedForward` default is used.
+        ff_bias (`bool`, *optional*, defaults to `True`): Whether to use a bias in the feed-forward layer.
+        skip (`bool`, *optional*, defaults to `False`):
+            Whether to add a long skip connection (used in the second half of the HunyuanDiT blocks).
+        qk_norm (`bool`, *optional*, defaults to `True`):
+            Whether to apply layer normalization to the query and key projections.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        text_dim: int = 1024,
+        dropout=0.0,
+        activation_fn: str = "geglu",
+        norm_elementwise_affine: bool = True,
+        norm_eps: float = 1e-6,
+        final_dropout: bool = False,
+        ff_inner_dim: Optional[int] = None,
+        ff_bias: bool = True,
+        skip: bool = False,
+        qk_norm: bool = True,
+    ):
+        super().__init__()
+
+        # Define 3 blocks. Each block has its own normalization layer.
+        # 1. Self-Attn
+        self.norm1 = FP32_Layernorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+        from .attention_processor import HunyuanAttnProcessor2_0
+        self.attn1 = Attention(
+            query_dim=dim,
+            cross_attention_dim=dim,
+            dim_head=dim // num_attention_heads,
+            heads=num_attention_heads,
+            qk_norm="layer_norm" if qk_norm else None,
+            eps=1e-6,
+            bias=True,
+            processor=HunyuanAttnProcessor2_0(),
+        )
+
+        # 2.
Cross-Attn + self.norm3 = FP32_Layernorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=text_dim, + dim_head = dim // num_attention_heads, + heads = num_attention_heads, + qk_norm="layer_norm" if qk_norm else None, + eps=1e-6, + bias=True, + processor= HunyuanAttnProcessor2_0(), + ) + # 3. Feed-forward + self.norm2 = FP32_Layernorm(dim, norm_eps, norm_elementwise_affine) + + ### NOTE: do not switch norm2 and norm3, otherwise will load wrong key when using pretrained model! + + #print('mlp hidden dim:', ff_inner_dim) + self.ff = FeedForward( + dim, + dropout=dropout, ### 0.0 + activation_fn=activation_fn, ### approx GeLU + final_dropout=final_dropout, ### 0.0 + inner_dim=ff_inner_dim, ### int(dim * mlp_ratio) + bias=ff_bias, + ) + + # 4. Skip Connection + if skip: + self.skip_norm = FP32_Layernorm(2 * dim, norm_eps, elementwise_affine=True) + self.skip_linear = nn.Linear(2 * dim, dim) + else: + self.skip_linear = None + + # 5. SDXL-style modulation with add + self.default_modulation = nn.Sequential( + FP32_SiLU(), + nn.Linear(dim, dim, bias=True) + ) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + freq_cis_img = None, + skip=None + ) -> torch.Tensor: + + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Long Skip Connection + if self.skip_linear is not None: + cat = torch.cat([hidden_states, skip], dim=-1) + cat = self.skip_norm(cat) + hidden_states = self.skip_linear(cat) + + #print('x:', hidden_states[0]) + # 1. Self-Attention + norm_hidden_states = self.norm1(hidden_states) ### checked: self.norm1 is correct + shift_msa = self.default_modulation(timestep).unsqueeze(dim=1) + attn_output = self.attn1( + norm_hidden_states + shift_msa, + temb = freq_cis_img, + ) + hidden_states = hidden_states + attn_output + #print('x:', hidden_states[0]) + + # 2. Cross-Attention + hidden_states = hidden_states + self.attn2( + self.norm3(hidden_states), + encoder_hidden_states = encoder_hidden_states, + temb = freq_cis_img, + ) + + #print('x:', hidden_states[0]) + + # FFN Layer ### NOTE: do not switch norm2 and norm3, otherwise will load wrong key when using pretrained model! 
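+        # NOTE: feed-forward chunking (`_chunk_size` set via `set_chunk_feed_forward`) is not applied here;
+        # the feed-forward layer always runs on the full sequence.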
+ mlp_inputs = self.norm2(hidden_states) + hidden_states = hidden_states + self.ff(mlp_inputs) + #print('x:', hidden_states[0]) + + return hidden_states @maybe_allow_in_graph class BasicTransformerBlock(nn.Module): diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index cbb07eafa37f..0f46c35fb07d 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -103,6 +103,7 @@ def __init__( upcast_softmax: bool = False, cross_attention_norm: Optional[str] = None, cross_attention_norm_num_groups: int = 32, + qk_norm: Optional[str] = None, added_kv_proj_dim: Optional[int] = None, norm_num_groups: Optional[int] = None, spatial_norm_dim: Optional[int] = None, @@ -160,6 +161,14 @@ def __init__( self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) else: self.spatial_norm = None + if qk_norm is None: + self.norm_q = None + self.norm_k = None + elif qk_norm == "layer_norm": + self.norm_q = nn.LayerNorm(dim_head, eps=eps) + self.norm_k = nn.LayerNorm(dim_head, eps=eps) + else: + raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'") if cross_attention_norm is None: self.norm_cross = None @@ -1426,6 +1435,109 @@ def __call__( return hidden_states +class HunyuanAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + apply_rotary_emb_on_key = False + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + apply_rotary_emb_on_key = True + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not 
None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if temb is not None: + if apply_rotary_emb_on_key: + qq, kk = apply_rotary_emb(query, key, temb, head_first=True) + assert qq.shape == query.shape and kk.shape == key.shape, \ + f'qq: {qq.shape}, q: {query.shape}, kk: {kk.shape}, key: {key.shape}' + query, key = qq, kk + else: + qq, _ = apply_rotary_emb(query, None, temb, head_first=True) + assert qq.shape == query.shape, f'qq: {qq.shape}, query: {query.shape}' + query = qq + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + class FusedAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses @@ -2697,3 +2809,91 @@ def __call__( LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor, ] + +from typing import Tuple +def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False): + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + Args: + freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + head_first (bool): head dimension first (except batch dim) or not. + + Returns: + torch.Tensor: Reshaped frequency tensor. + + Raises: + AssertionError: If the frequency tensor doesn't match the expected shape. + AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions. 
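+        Example:
+            With `head_first=False` and `x` of shape [B, S, H, D], a `freqs_cis` entry of shape [S, D] is
+            reshaped to [1, S, 1, D]; with `head_first=True` and `x` of shape [B, H, S, D] it is reshaped
+            to [1, 1, S, D].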
+ """ + ndim = x.ndim + assert 0 <= 1 < ndim + + if isinstance(freqs_cis, tuple): + # freqs_cis: (cos, sin) in real space + if head_first: + assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}' + shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + else: + assert freqs_cis[0].shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}' + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape) + else: + # freqs_cis: values in complex space + if head_first: + assert freqs_cis.shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}' + shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + else: + assert freqs_cis.shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}' + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def rotate_half(x): + x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + return torch.stack([-x_imag, x_real], dim=-1).flatten(3) + +def apply_rotary_emb( + xq: torch.Tensor, + xk: Optional[torch.Tensor], + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + head_first: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D] + xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D] + freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Precomputed frequency tensor for complex exponentials. + head_first (bool): head dimension first (except batch dim) or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
+ + """ + xk_out = None + if isinstance(freqs_cis, tuple): + cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D] + cos, sin = cos.to(xq.device), sin.to(xq.device) + xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq) + if xk is not None: + xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk) + else: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2] + freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2] + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq) + if xk is not None: + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [B, S, H, D//2] + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk) + + return xq_out, xk_out diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index dc78a72b2fb8..f1067f3313ca 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -7,3 +7,4 @@ from .t5_film_transformer import T5FilmDecoder from .transformer_2d import Transformer2DModel from .transformer_temporal import TransformerTemporalModel + from .hunyuan_transformer_2d import HunyuanDiT2DModel diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py new file mode 100644 index 000000000000..600a7c4ba57b --- /dev/null +++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py @@ -0,0 +1,412 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput, deprecate, is_torch_version, logging +from ..attention import HunyuanDiTBlock, FP32_SiLU, FP32_Layernorm, HunyuanDiTAttentionPool, modulate +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormSingle + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class Transformer2DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: torch.Tensor + +import math +from einops import repeat +from timm.models.layers import to_2tuple +class HunyuanDiTPatchEmbed(nn.Module): + """ 2D Image to Patch Embedding + + Image to Patch Embedding using Conv2d + + A convolution based approach to patchifying a 2D image w/ embedding projection. 
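+
+    For an input of spatial size (H, W) and patch size (p, p), the projection produces
+    (H // p) * (W // p) patch tokens of dimension `embed_dim`.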
+
+    Based on the impl in https://github.com/google-research/vision_transformer
+
+    Hacked together by / Copyright 2020 Ross Wightman
+
+    The `_assert` calls are removed from `forward` so the layer is compatible with multi-resolution images.
+    """
+    def __init__(
+        self,
+        img_size=224,
+        patch_size=16,
+        in_chans=3,
+        embed_dim=768,
+        norm_layer=None,
+        flatten=True,
+        bias=True,
+    ):
+        super().__init__()
+        if isinstance(img_size, int):
+            img_size = to_2tuple(img_size)
+        elif isinstance(img_size, (tuple, list)) and len(img_size) == 2:
+            img_size = tuple(img_size)
+        else:
+            raise ValueError(f"img_size must be int or tuple/list of length 2. Got {img_size}")
+        patch_size = to_2tuple(patch_size)
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+        self.num_patches = self.grid_size[0] * self.grid_size[1]
+        self.flatten = flatten
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
+        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+    def update_image_size(self, img_size):
+        self.img_size = img_size
+        self.grid_size = (img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1])
+        self.num_patches = self.grid_size[0] * self.grid_size[1]
+
+    def forward(self, x):
+        # B, C, H, W = x.shape
+        # _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
+        # _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
+        x = self.proj(x)
+        if self.flatten:
+            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
+        x = self.norm(x)
+        return x
+
+def timestep_embedding(t, dim, max_period=10000, repeat_only=False):
+    """
+    Create sinusoidal timestep embeddings.
+    :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
+    :param dim: the dimension of the output.
+    :param max_period: controls the minimum frequency of the embeddings.
+    :return: an (N, D) Tensor of positional embeddings.
+    """
+    # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+    if not repeat_only:
+        half = dim // 2
+        freqs = torch.exp(
+            -math.log(max_period)
+            * torch.arange(start=0, end=half, dtype=torch.float32)
+            / half
+        ).to(device=t.device)  # size: [dim/2], an exponentially decaying curve of frequencies
+        args = t[:, None].float() * freqs[None]
+        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+        if dim % 2:
+            embedding = torch.cat(
+                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
+            )
+    else:
+        embedding = repeat(t, "b -> b d", d=dim)
+    return embedding
+
+
+class HunyuanDiTTimestepEmbedder(nn.Module):
+    """
+    Embeds scalar timesteps into vector representations.
+    """
+    def __init__(self, hidden_size, frequency_embedding_size=256, out_size=None):
+        super().__init__()
+        if out_size is None:
+            out_size = hidden_size
+        self.mlp = nn.Sequential(
+            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+            nn.SiLU(),
+            nn.Linear(hidden_size, out_size, bias=True),
+        )
+        self.frequency_embedding_size = frequency_embedding_size
+
+    def forward(self, t):
+        t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
+        t_emb = self.mlp(t_freq)
+        return t_emb
+
+class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
+    """
+    HunYuanDiT: Diffusion model with a Transformer backbone.
+
+    Inherits `ModelMixin` and `ConfigMixin` to be compatible with diffusers samplers such as `StableDiffusionPipeline`.
+
+    Parameters
+    ----------
+    num_attention_heads: int
+        The number of attention heads.
+    attention_head_dim: int
+        The number of channels in each attention head.
+    in_channels: int
+        The number of input channels (and of output channels when `learn_sigma=False`).
+    patch_size: int
+        The size of the patches used to embed the input latents.
+    activation_fn: str
+        The activation function used in the feed-forward layers.
+    input_size: tuple
+        The default spatial size of the input latents.
+    hidden_size: int
+        The hidden size of the conditioning embedders.
+    num_layers: int
+        The number of transformer blocks.
+    mlp_ratio: float
+        The ratio of the feed-forward hidden size to the block width.
+    learn_sigma: bool
+        Whether the model also predicts the noise variance (doubles the number of output channels).
+    text_dim: int
+        The dimension of the CLIP text embeddings.
+    norm_type: str
+        The normalization layer to use.
+    """
+    @register_to_config
+    def __init__(
+        self,
+        num_attention_heads: int = 16,
+        attention_head_dim: int = 88,
+        in_channels: Optional[int] = None,
+        patch_size: Optional[int] = None,
+        activation_fn: str = "gelu-approximate",
+        input_size=(32, 32),
+        hidden_size=1152,
+        num_layers: int = 28,
+        mlp_ratio: float = 4.0,
+        learn_sigma: bool = True,
+        text_dim: int = 1024,
+        norm_type: str = "layer_norm",
+    ):
+        super().__init__()
+        self.depth = num_layers
+        self.learn_sigma = learn_sigma
+        self.in_channels = in_channels
+        self.out_channels = in_channels * 2 if learn_sigma else in_channels
+        self.patch_size = patch_size
+        self.num_heads = num_attention_heads
+        self.hidden_size = hidden_size
+        self.text_states_dim = text_dim
+        self.text_states_dim_t5 = 2048
+        self.text_len = 77
+        self.text_len_t5 = 256  ### NOTE: these text lengths are hardcoded for now and are not expected to change in the near future
+        self.norm = norm_type
+        self.inner_dim = num_attention_heads * attention_head_dim
+
+        # Projects the T5 text embeddings into the CLIP text embedding space
+        self.mlp_t5 = nn.Sequential(
+            nn.Linear(self.text_states_dim_t5, self.text_states_dim_t5 * 4, bias=True),
+            FP32_SiLU(),
+            nn.Linear(self.text_states_dim_t5 * 4, self.text_states_dim, bias=True),
+        )
+        # Learnable padding embeddings that replace masked-out text tokens
+        self.text_embedding_padding = nn.Parameter(
+            torch.randn(self.text_len + self.text_len_t5, self.text_states_dim, dtype=torch.float32))
+
+        # Attention pooling over the T5 text embeddings
+        self.pooler = HunyuanDiTAttentionPool(self.text_len_t5, self.text_states_dim_t5, num_heads=8, output_dim=1024)
+
+        # Here we use a default learned embedder layer for future extension.
+        self.style_embedder = nn.Embedding(1, hidden_size)
+
+        # Image size and crop size conditions (6 values, each embedded to 256 dims) plus the style embedding
+        self.extra_in_dim = 256 * 6 + hidden_size
+
+        # Patch and timestep embedders
+        self.x_embedder = HunyuanDiTPatchEmbed(input_size, patch_size, in_channels, hidden_size)
+        self.t_embedder = HunyuanDiTTimestepEmbedder(hidden_size)
+        # The pooled text embedding (see `self.pooler`) is also part of the extra conditioning vector
+        self.extra_in_dim += 1024
+        self.extra_embedder = nn.Sequential(
+            nn.Linear(self.extra_in_dim, hidden_size * 4),
+            FP32_SiLU(),
+            nn.Linear(hidden_size * 4, hidden_size, bias=True),
+        )
+
+        # HunyuanDiT Blocks
+        self.blocks = nn.ModuleList([
+            HunyuanDiTBlock(dim=self.inner_dim,
+                            num_attention_heads=self.config.num_attention_heads,
+                            activation_fn=activation_fn,
+                            ff_inner_dim=int(self.inner_dim * mlp_ratio),
+                            text_dim=self.config.text_dim,
+                            qk_norm=True,  # See http://arxiv.org/abs/2302.05442 for details.
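+                            # U-Net-style long skip connections are enabled for the second half of the blocks.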
+ skip=layer > self.depth // 2, + ) + for layer in range(self.depth) + ]) + + self.norm_final = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) + self.final_linear = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + self.final_adaLN_modulation = nn.Sequential( + FP32_SiLU(), + nn.Linear(hidden_size, 2 * self.inner_dim, bias=True) + ) + self.unpatchify_channels = self.out_channels + + self.initialize_weights() + + + def forward(self, + x, + t, + encoder_hidden_states=None, + text_embedding_mask=None, + encoder_hidden_states_t5=None, + text_embedding_mask_t5=None, + image_meta_size=None, + style=None, + cos_cis_img=None, + sin_cis_img=None, + return_dict=True, + ): + """ + Forward pass of the encoder. + + Parameters + ---------- + x: torch.Tensor + (B, D, H, W) + t: torch.Tensor + (B) + encoder_hidden_states: torch.Tensor + CLIP text embedding, (B, L_clip, D) + text_embedding_mask: torch.Tensor + CLIP text embedding mask, (B, L_clip) + encoder_hidden_states_t5: torch.Tensor + T5 text embedding, (B, L_t5, D) + text_embedding_mask_t5: torch.Tensor + T5 text embedding mask, (B, L_t5) + image_meta_size: torch.Tensor + (B, 6) + style: torch.Tensor + (B) + cos_cis_img: torch.Tensor + sin_cis_img: torch.Tensor + return_dict: bool + Whether to return a dictionary. + """ + + text_states = encoder_hidden_states # 2,77,1024 + text_states_t5 = encoder_hidden_states_t5 # 2,256,2048 + text_states_mask = text_embedding_mask.bool() # 2,77 + text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256 + b_t5, l_t5, c_t5 = text_states_t5.shape + text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)) + text_states = torch.cat([text_states, text_states_t5.view(b_t5, l_t5, -1)], dim=1) # 2,205,1024 + clip_t5_mask = torch.cat([text_states_mask, text_states_t5_mask], dim=-1) + + clip_t5_mask = clip_t5_mask + text_states = torch.where(clip_t5_mask.unsqueeze(2), text_states, self.text_embedding_padding.to(text_states)) + + _, _, oh, ow = x.shape + th, tw = oh // self.patch_size, ow // self.patch_size + + # ========================= Build time and image embedding ========================= + t = self.t_embedder(t) + x = self.x_embedder(x) + + # Get image RoPE embedding according to `reso`lution. 
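+        # `cos_cis_img` / `sin_cis_img` are precomputed outside of the model for the current latent resolution
+        # (e.g. via `calc_rope` in the pipeline, which wraps `get_2d_rotary_pos_embed(..., use_real=True)`).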
+ freqs_cis_img = (cos_cis_img, sin_cis_img) + + # ========================= Concatenate all extra vectors ========================= + # Build text tokens with pooling + extra_vec = self.pooler(encoder_hidden_states_t5) + + # Build image meta size tokens + image_meta_size = timestep_embedding(image_meta_size.view(-1), 256) # [B * 6, 256] + + image_meta_size = image_meta_size.to(dtype=x.dtype) + image_meta_size = image_meta_size.view(-1, 6 * 256) + extra_vec = torch.cat([extra_vec, image_meta_size], dim=1) # [B, D + 6 * 256] + + # Build style tokens + style_embedding = self.style_embedder(style) + extra_vec = torch.cat([extra_vec, style_embedding], dim=1) + + # Concatenate all extra vectors + c = t + self.extra_embedder(extra_vec) # [B, D] + + # ========================= Forward pass through HunYuanDiT blocks ========================= + skips = [] + for layer, block in enumerate(self.blocks): + if layer > self.depth // 2: + skip = skips.pop() + x = block(x, timestep=c, encoder_hidden_states=text_states, freq_cis_img=freqs_cis_img, skip=skip) # (N, L, D) + else: + x = block(x, timestep=c, encoder_hidden_states=text_states, freq_cis_img=freqs_cis_img,) # (N, L, D) + + if layer < (self.depth // 2 - 1): + skips.append(x) + + # ========================= Final layer ========================= + x = self._get_output_for_patched_inputs(x, c) # (N, L, patch_size ** 2 * out_channels) + x = self.unpatchify(x, th, tw) # (N, out_channels, H, W) + + if return_dict: + return {'x': x} + return x + + def _get_output_for_patched_inputs(self, hidden_states, timestep): + shift, scale = self.final_adaLN_modulation(timestep).chunk(2, dim=1) + hidden_states = modulate(self.norm_final(hidden_states), shift, scale) + hidden_states = self.final_linear(hidden_states) + return hidden_states + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.extra_embedder[0].weight, std=0.02) + nn.init.normal_(self.extra_embedder[2].weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in HunYuanDiT blocks: + for block in self.blocks: + nn.init.constant_(block.default_modulation[-1].weight, 0) + nn.init.constant_(block.default_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_linear.weight, 0) + nn.init.constant_(self.final_linear.bias, 0) + + def unpatchify(self, x, h, w): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.unpatchify_channels + p = self.x_embedder.patch_size[0] + # h = w = int(x.shape[1] ** 0.5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p)) + return imgs diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 
c2dd7ac0d551..da20f4c7b412 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -149,6 +149,7 @@ "IFPipeline", "IFSuperResolutionPipeline", ] + _import_structure["hunyuandit"] = ["HunyuanDiTPipeline"] _import_structure["kandinsky"] = [ "KandinskyCombinedPipeline", "KandinskyImg2ImgCombinedPipeline", @@ -411,6 +412,7 @@ VersatileDiffusionTextToImagePipeline, VQDiffusionPipeline, ) + from .hunyuandit import HunyuanDiTPipeline from .i2vgen_xl import I2VGenXLPipeline from .kandinsky import ( KandinskyCombinedPipeline, diff --git a/src/diffusers/pipelines/hunyuandit/__init__.py b/src/diffusers/pipelines/hunyuandit/__init__.py new file mode 100644 index 000000000000..8337399106f0 --- /dev/null +++ b/src/diffusers/pipelines/hunyuandit/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_hunyuandit"] = ["HunyuanDiTPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_hunyuandit import HunyuanDiTPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py new file mode 100644 index 000000000000..8531c521b8bd --- /dev/null +++ b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py @@ -0,0 +1,1120 @@ +# Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import html +import inspect +import re +import PIL +import numpy as np +import urllib.parse as ul +from typing import Callable, List, Optional, Tuple, Union, Dict, Any +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor + +import torch +from transformers import T5EncoderModel, MT5Tokenizer, CLIPImageProcessor +from transformers import BertModel, BertTokenizer +from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker + +from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models.lora import adjust_lora_scale_text_encoder + +import torch.nn as nn + +from ...models import AutoencoderKL, HunyuanDiT2DModel + +from ...utils import ( + PIL_INTERPOLATION, + deprecate, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import DDPMScheduler + +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import HunyuanDiTPipeline + + >>> pipe = HunyuanDiTPipeline.from_pretrained("Tencent-Hunyuan/HunyuanDiT", torch_dtype=torch.float16) + >>> pipe.to('cuda') + + >>> # You may also use English prompt as HunyuanDiT supports both English and Chinese + >>> # prompt = "An astronaut riding a horse" + >>> prompt = "一个宇航员在骑马" + >>> image = pipe(prompt).images[0] + ``` +""" + +STANDARD_RATIO = np.array([ + 1.0, # 1:1 + 4.0 / 3.0, # 4:3 + 3.0 / 4.0, # 3:4 + 16.0 / 9.0, # 16:9 + 9.0 / 16.0, # 9:16 +]) +STANDARD_SHAPE = [ + [(1024, 1024), (1280, 1280)], # 1:1 + [(1024, 768), (1152, 864), (1280, 960)], # 4:3 + [(768, 1024), (864, 1152), (960, 1280)], # 3:4 + [(1280, 768)], # 16:9 + [(768, 1280)], # 9:16 +] +STANDARD_AREA = [ + np.array([w * h for w, h in shapes]) + for shapes in STANDARD_SHAPE +] +SUPPORTED_SHAPE = [ + (1024, 1024), (1280, 1280), # 1:1 + (1024, 768), (1152, 864), (1280, 960), # 4:3 + (768, 1024), (864, 1152), (960, 1280), # 3:4 + (1280, 768), # 16:9 + (768, 1280), # 9:16 +] + +def map_to_standard_shapes(target_width, target_height): + target_ratio = target_width / target_height + closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio)) + closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height)) + width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx] + return width, height + +def _to_tuple(x): + if isinstance(x, int): + return x, x + else: + return x + + +def get_fill_resize_and_crop(src, tgt): + th, tw = _to_tuple(tgt) + h, w = _to_tuple(src) + + tr = th / tw + r = h / w + + # resize + if r > tr: + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +def get_meshgrid(start, *args): + if len(args) == 0: + # start is grid_size + num = _to_tuple(start) + start = (0, 0) + stop = num + elif len(args) == 1: + # start is start, args[0] is stop, step is 1 + start = _to_tuple(start) + stop = _to_tuple(args[0]) + num = (stop[0] - start[0], stop[1] - start[1]) + elif len(args) == 2: + # start is start, args[0] is stop, args[1] is num + start = 
_to_tuple(start) # up-left eg: 12,0 + stop = _to_tuple(args[0]) # bottom-right eg: 20,32 + num = _to_tuple(args[1]) # target size eg: 32,124 + else: + raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") + + grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32) + grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) # [2, W, H] + return grid + +################################################################################# +# Sine/Cosine Positional Embedding Functions # +################################################################################# +# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py + +def get_2d_sincos_pos_embed(embed_dim, start, *args, cls_token=False, extra_tokens=0): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid = get_meshgrid(start, *args) # [2, H, w] + # grid_h = np.arange(grid_size, dtype=np.float32) + # grid_w = np.arange(grid_size, dtype=np.float32) + # grid = np.meshgrid(grid_w, grid_h) # here w goes first + # grid = np.stack(grid, axis=0) # [2, W, H] + + grid = grid.reshape([2, 1, *grid.shape[1:]]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (W,H) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +################################################################################# +# Rotary Positional Embedding Functions # +################################################################################# +# https://github.com/facebookresearch/llama/blob/main/llama/model.py#L443 + +def get_2d_rotary_pos_embed(embed_dim, start, *args, use_real=True): + """ + This is a 2d version of precompute_freqs_cis, which is a RoPE for image tokens with 2d structure. + + Parameters + ---------- + embed_dim: int + embedding dimension size + start: int or tuple of int + If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, step is 1; + If len(args) == 2, start is start, args[0] is stop, args[1] is num. + use_real: bool + If True, return real part and imaginary part separately. Otherwise, return complex numbers. 
+ + Returns + ------- + pos_embed: torch.Tensor + [HW, D/2] + """ + grid = get_meshgrid(start, *args) # [2, H, w] + grid = grid.reshape([2, 1, *grid.shape[1:]]) + pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real) + return pos_embed + + +def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False): + assert embed_dim % 4 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4) + emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4) + + if use_real: + cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2) + sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2) + return cos, sin + else: + emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2) + return emb + + +def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + pos (np.ndarray, int): Position indices for the frequency tensor. [S] or scalar + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + use_real (bool, optional): If True, return real part and imaginary part separately. + Otherwise, return complex numbers. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. [S, D/2] + + """ + if isinstance(pos, int): + pos = np.arange(pos) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2] + t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S] + freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2] + if use_real: + freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] + return freqs_cos, freqs_sin + else: + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] + return freqs_cis + + + +def calc_sizes(rope_img, patch_size, th, tw): + """ compute the size of RoPE. 
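+    `rope_img` is either `'extend'` (use the target `(th, tw)` grid directly) or `'base<size>'`
+    (interpolate the grid from a base resolution of `<size>` pixels, e.g. `'base512'`).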
""" + if rope_img == 'extend': + sub_args = [(th, tw)] + elif rope_img.startswith('base'): + # Interpolate based on a base size + base_size = int(rope_img[4:]) // 8 // patch_size # 512 as the base + start, stop = get_fill_resize_and_crop((th, tw), base_size) # up-left and bottom-right in 32 by 32 + sub_args = [start, stop, (th, tw)] + else: + raise ValueError(f"Unknown rope_img: {rope_img}") + return sub_args + + +def init_image_posemb(rope_img, + resolutions, + patch_size, + hidden_size, + num_heads, + log_fn, + rope_real=True, + ): + freqs_cis_img = {} + for reso in resolutions: + th, tw = reso.height // 8 // patch_size, reso.width // 8 // patch_size + sub_args = calc_sizes(rope_img, patch_size, th, tw) # [up-left, bottom-right, target height & width] + freqs_cis_img[str(reso)] = get_2d_rotary_pos_embed(hidden_size // num_heads, *sub_args, use_real=rope_real) + log_fn(f" Using image RoPE ({rope_img}) ({'real' if rope_real else 'complex'}): {sub_args} | ({reso}) " + f"{freqs_cis_img[str(reso)][0].shape if rope_real else freqs_cis_img[str(reso)].shape}") + return freqs_cis_img + + +def calc_rope(height, width, patch_size, head_size): + th = height // 8 // patch_size + tw = width // 8 // patch_size + base_size = 512 // 8 // patch_size + start, stop = get_fill_resize_and_crop((th, tw), base_size) + sub_args = [start, stop, (th, tw)] + rope = get_2d_rotary_pos_embed(head_size, *sub_args) + return rope + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + +def preprocess(image): + deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" + deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + +class HunyuanDiTPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin): + r""" + Pipeline for English/Chinese-to-image generation using HunyuanDiT. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
+ + HunyuanDiT uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by ourselves) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + We use `sdxl-vae-fp16-fix`. + text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + HunyuanDiT uses a fine-tuned [bilingual CLIP]. + tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]): + A `BertTokenizer` or `CLIPTokenizer` to tokenize text. + transformer ([`HunyuanDiT2DModel`]): + The HunyuanDiT model designed by Tencent Hunyuan. + embedder_t5 (`MT5Embedder`): + The mT5 embedder. Specifically, it is 't5-v1_1-xxl'. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->embedder_t5->tokenizer->tokenizer_t5->transformer->vae" + _optional_components = ["safety_checker", "feature_extractor", "embedder_t5", "tokenizer_t5"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: BertModel, + tokenizer: BertTokenizer, + transformer: HunyuanDiT2DModel, + scheduler: DDPMScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + progress_bar_config: Dict[str, Any] = None, + embedder_t5=T5EncoderModel, + tokenizer_t5=MT5Tokenizer, + infer_mode='torch', + ): + super().__init__() + + # ======================================================== + self.infer_mode = infer_mode + + # ======================================================== + if progress_bar_config is None: + progress_bar_config = {} + if not hasattr(self, '_progress_bar_config'): + self._progress_bar_config = {} + self._progress_bar_config.update(progress_bar_config) + # ======================================================== + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + tokenizer_t5=tokenizer_t5, + transformer=transformer, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + embedder_t5=embedder_t5, + ) + + self.text_encoder.pooler.to_empty(device='cpu') ### workaround for the meta device in pooler... + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. 
Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + embedder=None, + ): + r""" + Encodes the prompt into text encoder hidden states. 
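+        Unlike the Stable Diffusion version of `encode_prompt`, this returns a tuple of
+        `(prompt_embeds, negative_prompt_embeds, attention_mask, uncond_attention_mask)`.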
+ + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + embedder: + T5 embedder (including text encoder and tokenizer) + """ + if embedder is None: + text_encoder = self.text_encoder + tokenizer = self.tokenizer + max_length = self.tokenizer.model_max_length + else: + text_encoder = embedder['model'] + tokenizer = embedder['tokenizer'] + max_length = embedder['max_length'] + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode( + untruncated_ids[:, tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + attention_mask = attention_mask.repeat(num_images_per_prompt, 1) + else: + attention_mask = None + + if text_encoder is not None: + prompt_embeds_dtype = text_encoder.dtype + elif self.transformer is not None: + prompt_embeds_dtype = self.transformer.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = 
prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + uncond_attention_mask = uncond_input.attention_mask.to(device) + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + attention_mask=uncond_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + uncond_attention_mask = uncond_attention_mask.repeat(num_images_per_prompt, 1) + else: + uncond_attention_mask = None + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds, attention_mask, uncond_attention_mask + + def _convert_to_rgb(self, image): + return image.convert('RGB') + + def image_transform(self, image_size=224): + transform = T.Compose([ + T.Resize((image_size, image_size), interpolation=T.InterpolationMode.BICUBIC), + self._convert_to_rgb, + T.ToTensor(), + T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + return transform + + def encode_img(self, img, device, do_classifier_free_guidance): + img = img[0] # TODO: support batch processing + image_preprocess = self.image_transform(224) + img_for_clip = image_preprocess(img) + + img_for_clip = img_for_clip.unsqueeze(0) + img_clip_embedding = self.img_encoder(img_for_clip.to(device)).to(dtype=torch.float16) + + if do_classifier_free_guidance: + negative_img_clip_embedding = torch.zeros_like(img_clip_embedding) + return img_clip_embedding, negative_img_clip_embedding + + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + 
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
+ ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + height: int, + width: int, + prompt: Union[str, List[str]] = None, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_embeds_t5: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds_t5: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + image_meta_size: Optional[torch.LongTensor] = None, + style: Optional[torch.LongTensor] = None, + progress: bool = True, + use_fp16: bool = False, + freqs_cis_img: Optional[tuple] = None, + learn_sigma: bool = True, + ): + r""" + The call function to the pipeline for generation with HunyuanDiT. + + Args: + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what not to include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + prompt_embeds_t5 (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings from the T5 text encoder. If not provided, they are generated from the + `prompt` input argument. + negative_prompt_embeds_t5 (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings from the T5 text encoder. If not provided, they are generated + from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that is called every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor, + pred_x0: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step.
+ cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that, if specified, is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Calculate necessary elements for HunyuanDiT + target_height = int((height // 16) * 16) + target_width = int((width // 16) * 16) + logger.debug(f"Align to 16: (height, width) = ({target_height}, {target_width})") + + if (target_height, target_width) not in SUPPORTED_SHAPE: + target_width, target_height = map_to_standard_shapes(target_width, target_height) + height = int(target_height) + width = int(target_width) + logger.warning(f"Reshaped to (height, width) = ({target_height}, {target_width}); supported shapes are {SUPPORTED_SHAPE}.") + + freqs_cis_img = calc_rope(target_height, target_width, patch_size=self.transformer.config.patch_size, \ + head_size=self.transformer.inner_dim // self.transformer.num_heads) + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 4. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + + prompt_embeds, negative_prompt_embeds, attention_mask, uncond_attention_mask = \ + self.encode_prompt(prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + prompt_embeds_t5, negative_prompt_embeds_t5, attention_mask_t5, uncond_attention_mask_t5 = \ + self.encode_prompt(prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds_t5, + negative_prompt_embeds=negative_prompt_embeds_t5, + lora_scale=text_encoder_lora_scale, + embedder={'model': self.embedder_t5, 'tokenizer': self.tokenizer_t5, 'max_length': 256}, + ) + + # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + attention_mask = torch.cat([uncond_attention_mask, attention_mask]) + prompt_embeds_t5 = torch.cat([negative_prompt_embeds_t5, prompt_embeds_t5]) + attention_mask_t5 = torch.cat([uncond_attention_mask_t5, attention_mask_t5]) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents(batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # ======================================================================== + # Arguments: style. (A fixed argument. Don't Change it.) + # ======================================================================== + style = torch.as_tensor([0, 0] * batch_size, device=self._execution_device) + + # ======================================================================== + # Inner arguments: image_meta_size (Please refer to SDXL.) + # ======================================================================== + src_size_cond = (1024, 1024) + if isinstance(src_size_cond, int): + src_size_cond = [src_size_cond, src_size_cond] + if not isinstance(src_size_cond, (list, tuple)): + raise TypeError(f"`src_size_cond` must be a list or tuple, but got {type(src_size_cond)}") + if len(src_size_cond) != 2: + raise ValueError(f"`src_size_cond` must be a tuple of 2 integers, but got {len(src_size_cond)}") + size_cond = list(src_size_cond) + [target_width, target_height, 0, 0] + image_meta_size = torch.as_tensor([size_cond] * 2 * batch_size, device=self._execution_device) + + # 8. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input + t_expand = torch.tensor([t] * latent_model_input.shape[0], device=latent_model_input.device) + + if use_fp16: + latent_model_input = latent_model_input.half() + t_expand = t_expand.half() + prompt_embeds = prompt_embeds.half() + ims = image_meta_size.half() if image_meta_size is not None else None + else: + ims = image_meta_size if image_meta_size is not None else None + + # predict the noise residual + if self.infer_mode in ["fa", "torch"]: + noise_pred = self.transformer( + latent_model_input, + t_expand, + encoder_hidden_states=prompt_embeds, + text_embedding_mask=attention_mask, + encoder_hidden_states_t5=prompt_embeds_t5, + text_embedding_mask_t5=attention_mask_t5, + image_meta_size=ims, + style=style, + cos_cis_img=freqs_cis_img[0], + sin_cis_img=freqs_cis_img[1], + return_dict=False, + ) + elif self.infer_mode == "trt": + raise NotImplementedError("TensorRT model is not supported yet.") + else: + raise ValueError("[ERROR] invalid inference mode! please check your config file") + if learn_sigma: + noise_pred, _ = noise_pred.chunk(2, dim=1) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + results = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=True) + latents = results.prev_sample + pred_x0 = results.pred_original_sample if hasattr(results, 'pred_original_sample') else None + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents, pred_x0) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) \ No newline at end of file diff --git a/test_hunyuan_dit.py b/test_hunyuan_dit.py new file mode 100644 index 000000000000..b88863ffd16a --- /dev/null +++ b/test_hunyuan_dit.py @@ -0,0 +1,12 @@ +import torch +from diffusers import HunyuanDiTPipeline + +pipe = HunyuanDiTPipeline.from_pretrained("XCLiu/HunyuanDiT-0523", torch_dtype=torch.float32) +pipe.to('cuda') + +### NOTE: HunyuanDiT supports both Chinese and English inputs +prompt = "一个宇航员在骑马" +#prompt = "An astronaut riding a horse" +image = pipe(height=1024, width=1024, prompt=prompt).images[0] + +image.save("./img.png") \ No newline at end of file
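
For reference, a slightly fuller usage sketch building on `test_hunyuan_dit.py`: it exercises the classifier-free-guidance arguments exposed by the new `__call__` signature (`negative_prompt`, `guidance_scale`, `num_images_per_prompt`, `generator`). The checkpoint name is the one used in the test script; the prompt and argument values are illustrative placeholders, not tuned settings.

import torch
from diffusers import HunyuanDiTPipeline

# Load the pipeline added in this PR; checkpoint name follows test_hunyuan_dit.py.
pipe = HunyuanDiTPipeline.from_pretrained("XCLiu/HunyuanDiT-0523", torch_dtype=torch.float32)
pipe.to("cuda")

# guidance_scale > 1 enables classifier-free guidance inside __call__;
# negative_prompt feeds the unconditional half of the batched forward pass.
generator = torch.Generator(device="cuda").manual_seed(0)
image = pipe(
    height=1024,
    width=1024,
    prompt="An astronaut riding a horse",
    negative_prompt="lowres, blurry, watermark",
    guidance_scale=5.0,
    num_inference_steps=50,
    num_images_per_prompt=1,
    generator=generator,
).images[0]

image.save("./astronaut.png")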