diff --git a/docs/source/en/api/models/prior_transformer.md b/docs/source/en/api/models/prior_transformer.md index 0b849c300662..190b8c9b91d3 100644 --- a/docs/source/en/api/models/prior_transformer.md +++ b/docs/source/en/api/models/prior_transformer.md @@ -24,4 +24,4 @@ The abstract from the paper is: ## PriorTransformerOutput -[[autodoc]] models.prior_transformer.PriorTransformerOutput +[[autodoc]] models.transformers.prior_transformer.PriorTransformerOutput diff --git a/docs/source/en/api/models/transformer2d.md b/docs/source/en/api/models/transformer2d.md index 0f891edd754a..6c427890fd58 100644 --- a/docs/source/en/api/models/transformer2d.md +++ b/docs/source/en/api/models/transformer2d.md @@ -38,4 +38,4 @@ It is assumed one of the input classes is the masked latent pixel. The predicted ## Transformer2DModelOutput -[[autodoc]] models.transformer_2d.Transformer2DModelOutput +[[autodoc]] models.transformers.transformer_2d.Transformer2DModelOutput diff --git a/docs/source/en/api/models/transformer_temporal.md b/docs/source/en/api/models/transformer_temporal.md index c936270b7927..dce222d07e4d 100644 --- a/docs/source/en/api/models/transformer_temporal.md +++ b/docs/source/en/api/models/transformer_temporal.md @@ -16,8 +16,8 @@ A Transformer model for video-like data. ## TransformerTemporalModel -[[autodoc]] models.transformer_temporal.TransformerTemporalModel +[[autodoc]] models.transformers.transformer_temporal.TransformerTemporalModel ## TransformerTemporalModelOutput -[[autodoc]] models.transformer_temporal.TransformerTemporalModelOutput +[[autodoc]] models.transformers.transformer_temporal.TransformerTemporalModelOutput diff --git a/scripts/convert_kakao_brain_unclip_to_diffusers.py b/scripts/convert_kakao_brain_unclip_to_diffusers.py index b02cb498bb9b..5135eaed5b98 100644 --- a/scripts/convert_kakao_brain_unclip_to_diffusers.py +++ b/scripts/convert_kakao_brain_unclip_to_diffusers.py @@ -6,7 +6,7 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer from diffusers import UnCLIPPipeline, UNet2DConditionModel, UNet2DModel -from diffusers.models.prior_transformer import PriorTransformer +from diffusers.models.transformers.prior_transformer import PriorTransformer from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel from diffusers.schedulers.scheduling_unclip import UnCLIPScheduler diff --git a/scripts/convert_kandinsky_to_diffusers.py b/scripts/convert_kandinsky_to_diffusers.py index 1b5722f5d5f3..8d3f7b63d0c1 100644 --- a/scripts/convert_kandinsky_to_diffusers.py +++ b/scripts/convert_kandinsky_to_diffusers.py @@ -6,7 +6,7 @@ from accelerate import load_checkpoint_and_dispatch from diffusers import UNet2DConditionModel -from diffusers.models.prior_transformer import PriorTransformer +from diffusers.models.transformers.prior_transformer import PriorTransformer from diffusers.models.vq_model import VQModel diff --git a/scripts/convert_shap_e_to_diffusers.py b/scripts/convert_shap_e_to_diffusers.py index cacd2f7ba309..b903b4ee8a7f 100644 --- a/scripts/convert_shap_e_to_diffusers.py +++ b/scripts/convert_shap_e_to_diffusers.py @@ -4,7 +4,7 @@ import torch from accelerate import load_checkpoint_and_dispatch -from diffusers.models.prior_transformer import PriorTransformer +from diffusers.models.transformers.prior_transformer import PriorTransformer from diffusers.pipelines.shap_e import ShapERenderer diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 02c94ddbf1de..30ba97f3868d 100644 --- 
a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -35,10 +35,10 @@ _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"] _import_structure["embeddings"] = ["ImageProjection"] _import_structure["modeling_utils"] = ["ModelMixin"] - _import_structure["prior_transformer"] = ["PriorTransformer"] - _import_structure["t5_film_transformer"] = ["T5FilmDecoder"] - _import_structure["transformer_2d"] = ["Transformer2DModel"] - _import_structure["transformer_temporal"] = ["TransformerTemporalModel"] + _import_structure["transformers.prior_transformer"] = ["PriorTransformer"] + _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"] + _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"] + _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] _import_structure["unets.unet_1d"] = ["UNet1DModel"] _import_structure["unets.unet_2d"] = ["UNet2DModel"] _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"] @@ -66,13 +66,15 @@ ConsistencyDecoderVAE, ) from .controlnet import ControlNetModel - from .dual_transformer_2d import DualTransformer2DModel from .embeddings import ImageProjection from .modeling_utils import ModelMixin - from .prior_transformer import PriorTransformer - from .t5_film_transformer import T5FilmDecoder - from .transformer_2d import Transformer2DModel - from .transformer_temporal import TransformerTemporalModel + from .transformers import ( + DualTransformer2DModel, + PriorTransformer, + T5FilmDecoder, + Transformer2DModel, + TransformerTemporalModel, + ) from .unets import ( Kandinsky3UNet, MotionAdapter, diff --git a/src/diffusers/models/dual_transformer_2d.py b/src/diffusers/models/dual_transformer_2d.py index 21b135c2eb86..1986aa543cd6 100644 --- a/src/diffusers/models/dual_transformer_2d.py +++ b/src/diffusers/models/dual_transformer_2d.py @@ -11,145 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from ..utils import deprecate +from .transformers.dual_transformer_2d import DualTransformer2DModel -from torch import nn -from .transformer_2d import Transformer2DModel, Transformer2DModelOutput - - -class DualTransformer2DModel(nn.Module): - """ - Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference. - - Parameters: - num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. - in_channels (`int`, *optional*): - Pass if the input is continuous. The number of channels in the input and output. - num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. - sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. - Note that this is fixed at training time as it is used for learning a number of position embeddings. See - `ImagePositionalEmbeddings`. - num_vector_embeds (`int`, *optional*): - Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. - Includes the class for the masked latent pixel. 
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. - The number of diffusion steps used during training. Note that this is fixed at training time as it is used - to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for - up to but not more than steps than `num_embeds_ada_norm`. - attention_bias (`bool`, *optional*): - Configure if the TransformerBlocks' attention should contain a bias parameter. - """ - - def __init__( - self, - num_attention_heads: int = 16, - attention_head_dim: int = 88, - in_channels: Optional[int] = None, - num_layers: int = 1, - dropout: float = 0.0, - norm_num_groups: int = 32, - cross_attention_dim: Optional[int] = None, - attention_bias: bool = False, - sample_size: Optional[int] = None, - num_vector_embeds: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - ): - super().__init__() - self.transformers = nn.ModuleList( - [ - Transformer2DModel( - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - in_channels=in_channels, - num_layers=num_layers, - dropout=dropout, - norm_num_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - attention_bias=attention_bias, - sample_size=sample_size, - num_vector_embeds=num_vector_embeds, - activation_fn=activation_fn, - num_embeds_ada_norm=num_embeds_ada_norm, - ) - for _ in range(2) - ] - ) - - # Variables that can be set by a pipeline: - - # The ratio of transformer1 to transformer2's output states to be combined during inference - self.mix_ratio = 0.5 - - # The shape of `encoder_hidden_states` is expected to be - # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)` - self.condition_lengths = [77, 257] - - # Which transformer to use to encode which condition. - # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])` - self.transformer_index_for_condition = [1, 0] - - def forward( - self, - hidden_states, - encoder_hidden_states, - timestep=None, - attention_mask=None, - cross_attention_kwargs=None, - return_dict: bool = True, - ): - """ - Args: - hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. - When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input - hidden_states. - encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): - Conditional embeddings for cross attention layer. If not given, cross-attention defaults to - self-attention. - timestep ( `torch.long`, *optional*): - Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. - attention_mask (`torch.FloatTensor`, *optional*): - Optional attention mask to be applied in Attention. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 
- - Returns: - [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: - [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. - """ - input_states = hidden_states - - encoded_states = [] - tokens_start = 0 - # attention_mask is not used yet - for i in range(2): - # for each of the two transformers, pass the corresponding condition tokens - condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]] - transformer_index = self.transformer_index_for_condition[i] - encoded_state = self.transformers[transformer_index]( - input_states, - encoder_hidden_states=condition_state, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - return_dict=False, - )[0] - encoded_states.append(encoded_state - input_states) - tokens_start += self.condition_lengths[i] - - output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio) - output_states = output_states + input_states - - if not return_dict: - return (output_states,) - - return Transformer2DModelOutput(sample=output_states) +class DualTransformer2DModel(DualTransformer2DModel): + deprecation_message = "Importing `DualTransformer2DModel` from `diffusers.models.dual_transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel`, instead." + deprecate("DualTransformer2DModel", "0.29", deprecation_message) diff --git a/src/diffusers/models/prior_transformer.py b/src/diffusers/models/prior_transformer.py index 081d66991faf..328835a95381 100644 --- a/src/diffusers/models/prior_transformer.py +++ b/src/diffusers/models/prior_transformer.py @@ -1,380 +1,12 @@ -from dataclasses import dataclass -from typing import Dict, Optional, Union +from ..utils import deprecate +from .transformers.prior_transformer import PriorTransformer, PriorTransformerOutput -import torch -import torch.nn.functional as F -from torch import nn -from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin -from ..utils import BaseOutput -from .attention import BasicTransformerBlock -from .attention_processor import ( - ADDED_KV_ATTENTION_PROCESSORS, - CROSS_ATTENTION_PROCESSORS, - AttentionProcessor, - AttnAddedKVProcessor, - AttnProcessor, -) -from .embeddings import TimestepEmbedding, Timesteps -from .modeling_utils import ModelMixin +class PriorTransformerOutput(PriorTransformerOutput): + deprecation_message = "Importing `PriorTransformerOutput` from `diffusers.models.prior_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.prior_transformer import PriorTransformerOutput`, instead." + deprecate("PriorTransformerOutput", "0.29", deprecation_message) -@dataclass -class PriorTransformerOutput(BaseOutput): - """ - The output of [`PriorTransformer`]. - - Args: - predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): - The predicted CLIP image embedding conditioned on the CLIP text embedding input. - """ - - predicted_image_embedding: torch.FloatTensor - - -class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): - """ - A Prior Transformer model. 
- - Parameters: - num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. - num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use. - embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states` - num_embeddings (`int`, *optional*, defaults to 77): - The number of embeddings of the model input `hidden_states` - additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the - projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings + - additional_embeddings`. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - time_embed_act_fn (`str`, *optional*, defaults to 'silu'): - The activation function to use to create timestep embeddings. - norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before - passing to Transformer blocks. Set it to `None` if normalization is not needed. - embedding_proj_norm_type (`str`, *optional*, defaults to None): - The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not - needed. - encoder_hid_proj_type (`str`, *optional*, defaults to `linear`): - The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if - `encoder_hidden_states` is `None`. - added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model. - Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot - product between the text embedding and image embedding as proposed in the unclip paper - https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended. - time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings. - If None, will be set to `num_attention_heads * attention_head_dim` - embedding_proj_dim (`int`, *optional*, default to None): - The dimension of `proj_embedding`. If None, will be set to `embedding_dim`. - clip_embed_dim (`int`, *optional*, default to None): - The dimension of the output. If None, will be set to `embedding_dim`. 
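A minimal usage sketch of the `PriorTransformer` interface documented above, assuming the default configuration (`embedding_dim=768`, `num_embeddings=77`); the batch size and random tensors are illustrative only:

import torch
from diffusers.models.transformers.prior_transformer import PriorTransformer

prior = PriorTransformer()                       # 32 heads x 64 dims, 20 layers by default
hidden_states = torch.randn(2, 768)              # currently predicted image embeddings
proj_embedding = torch.randn(2, 768)             # embedding the denoising is conditioned on
encoder_hidden_states = torch.randn(2, 77, 768)  # per-token text hidden states
out = prior(
    hidden_states,
    timestep=10,
    proj_embedding=proj_embedding,
    encoder_hidden_states=encoder_hidden_states,
)
print(out.predicted_image_embedding.shape)       # torch.Size([2, 768])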
- """ - - @register_to_config - def __init__( - self, - num_attention_heads: int = 32, - attention_head_dim: int = 64, - num_layers: int = 20, - embedding_dim: int = 768, - num_embeddings=77, - additional_embeddings=4, - dropout: float = 0.0, - time_embed_act_fn: str = "silu", - norm_in_type: Optional[str] = None, # layer - embedding_proj_norm_type: Optional[str] = None, # layer - encoder_hid_proj_type: Optional[str] = "linear", # linear - added_emb_type: Optional[str] = "prd", # prd - time_embed_dim: Optional[int] = None, - embedding_proj_dim: Optional[int] = None, - clip_embed_dim: Optional[int] = None, - ): - super().__init__() - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - inner_dim = num_attention_heads * attention_head_dim - self.additional_embeddings = additional_embeddings - - time_embed_dim = time_embed_dim or inner_dim - embedding_proj_dim = embedding_proj_dim or embedding_dim - clip_embed_dim = clip_embed_dim or embedding_dim - - self.time_proj = Timesteps(inner_dim, True, 0) - self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn) - - self.proj_in = nn.Linear(embedding_dim, inner_dim) - - if embedding_proj_norm_type is None: - self.embedding_proj_norm = None - elif embedding_proj_norm_type == "layer": - self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim) - else: - raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}") - - self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim) - - if encoder_hid_proj_type is None: - self.encoder_hidden_states_proj = None - elif encoder_hid_proj_type == "linear": - self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim) - else: - raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}") - - self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim)) - - if added_emb_type == "prd": - self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim)) - elif added_emb_type is None: - self.prd_embedding = None - else: - raise ValueError( - f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`." - ) - - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - dropout=dropout, - activation_fn="gelu", - attention_bias=True, - ) - for d in range(num_layers) - ] - ) - - if norm_in_type == "layer": - self.norm_in = nn.LayerNorm(inner_dim) - elif norm_in_type is None: - self.norm_in = None - else: - raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.") - - self.norm_out = nn.LayerNorm(inner_dim) - - self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim) - - causal_attention_mask = torch.full( - [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0 - ) - causal_attention_mask.triu_(1) - causal_attention_mask = causal_attention_mask[None, ...] 
- self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False) - - self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim)) - self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim)) - - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor - def set_default_attn_processor(self): - """ - Disables custom attention processors and sets the default attention implementation. 
- """ - if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): - processor = AttnAddedKVProcessor() - elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): - processor = AttnProcessor() - else: - raise ValueError( - f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" - ) - - self.set_attn_processor(processor) - - def forward( - self, - hidden_states, - timestep: Union[torch.Tensor, float, int], - proj_embedding: torch.FloatTensor, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.BoolTensor] = None, - return_dict: bool = True, - ): - """ - The [`PriorTransformer`] forward method. - - Args: - hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): - The currently predicted image embeddings. - timestep (`torch.LongTensor`): - Current denoising step. - proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): - Projected embedding vector the denoising process is conditioned on. - encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`): - Hidden states of the text embeddings the denoising process is conditioned on. - attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`): - Text mask for the text embeddings. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain - tuple. - - Returns: - [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`: - If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a - tuple is returned where the first element is the sample tensor. - """ - batch_size = hidden_states.shape[0] - - timesteps = timestep - if not torch.is_tensor(timesteps): - timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device) - elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: - timesteps = timesteps[None].to(hidden_states.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device) - - timesteps_projected = self.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might be fp16, so we need to cast here. 
- timesteps_projected = timesteps_projected.to(dtype=self.dtype) - time_embeddings = self.time_embedding(timesteps_projected) - - if self.embedding_proj_norm is not None: - proj_embedding = self.embedding_proj_norm(proj_embedding) - - proj_embeddings = self.embedding_proj(proj_embedding) - if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None: - encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states) - elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None: - raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set") - - hidden_states = self.proj_in(hidden_states) - - positional_embeddings = self.positional_embedding.to(hidden_states.dtype) - - additional_embeds = [] - additional_embeddings_len = 0 - - if encoder_hidden_states is not None: - additional_embeds.append(encoder_hidden_states) - additional_embeddings_len += encoder_hidden_states.shape[1] - - if len(proj_embeddings.shape) == 2: - proj_embeddings = proj_embeddings[:, None, :] - - if len(hidden_states.shape) == 2: - hidden_states = hidden_states[:, None, :] - - additional_embeds = additional_embeds + [ - proj_embeddings, - time_embeddings[:, None, :], - hidden_states, - ] - - if self.prd_embedding is not None: - prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1) - additional_embeds.append(prd_embedding) - - hidden_states = torch.cat( - additional_embeds, - dim=1, - ) - - # Allow positional_embedding to not include the `addtional_embeddings` and instead pad it with zeros for these additional tokens - additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1 - if positional_embeddings.shape[1] < hidden_states.shape[1]: - positional_embeddings = F.pad( - positional_embeddings, - ( - 0, - 0, - additional_embeddings_len, - self.prd_embedding.shape[1] if self.prd_embedding is not None else 0, - ), - value=0.0, - ) - - hidden_states = hidden_states + positional_embeddings - - if attention_mask is not None: - attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 - attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0) - attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype) - attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0) - - if self.norm_in is not None: - hidden_states = self.norm_in(hidden_states) - - for block in self.transformer_blocks: - hidden_states = block(hidden_states, attention_mask=attention_mask) - - hidden_states = self.norm_out(hidden_states) - - if self.prd_embedding is not None: - hidden_states = hidden_states[:, -1] - else: - hidden_states = hidden_states[:, additional_embeddings_len:] - - predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states) - - if not return_dict: - return (predicted_image_embedding,) - - return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding) - - def post_process_latents(self, prior_latents): - prior_latents = (prior_latents * self.clip_std) + self.clip_mean - return prior_latents +class PriorTransformer(PriorTransformer): + deprecation_message = "Importing `PriorTransformer` from `diffusers.models.prior_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.prior_transformer import PriorTransformer`, instead." 
+ deprecate("PriorTransformer", "0.29", deprecation_message) diff --git a/src/diffusers/models/t5_film_transformer.py b/src/diffusers/models/t5_film_transformer.py index 26ff3f6b8127..d06b5c9ec3c9 100644 --- a/src/diffusers/models/t5_film_transformer.py +++ b/src/diffusers/models/t5_film_transformer.py @@ -11,428 +11,60 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math -from typing import Optional, Tuple +from ..utils import deprecate +from .transformers.t5_film_transformer import ( + DecoderLayer, + NewGELUActivation, + T5DenseGatedActDense, + T5FilmDecoder, + T5FiLMLayer, + T5LayerCrossAttention, + T5LayerFFCond, + T5LayerNorm, + T5LayerSelfAttentionCond, +) -import torch -from torch import nn -from ..configuration_utils import ConfigMixin, register_to_config -from .attention_processor import Attention -from .embeddings import get_timestep_embedding -from .modeling_utils import ModelMixin +class T5FilmDecoder(T5FilmDecoder): + deprecation_message = "Importing `T5FilmDecoder` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5FilmDecoder`, instead." + deprecate("T5FilmDecoder", "0.29", deprecation_message) -class T5FilmDecoder(ModelMixin, ConfigMixin): - r""" - T5 style decoder with FiLM conditioning. +class DecoderLayer(DecoderLayer): + deprecation_message = "Importing `DecoderLayer` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import DecoderLayer`, instead." + deprecate("DecoderLayer", "0.29", deprecation_message) - Args: - input_dims (`int`, *optional*, defaults to `128`): - The number of input dimensions. - targets_length (`int`, *optional*, defaults to `256`): - The length of the targets. - d_model (`int`, *optional*, defaults to `768`): - Size of the input hidden states. - num_layers (`int`, *optional*, defaults to `12`): - The number of `DecoderLayer`'s to use. - num_heads (`int`, *optional*, defaults to `12`): - The number of attention heads to use. - d_kv (`int`, *optional*, defaults to `64`): - Size of the key-value projection vectors. - d_ff (`int`, *optional*, defaults to `2048`): - The number of dimensions in the intermediate feed-forward layer of `DecoderLayer`'s. - dropout_rate (`float`, *optional*, defaults to `0.1`): - Dropout probability. - """ - @register_to_config - def __init__( - self, - input_dims: int = 128, - targets_length: int = 256, - max_decoder_noise_time: float = 2000.0, - d_model: int = 768, - num_layers: int = 12, - num_heads: int = 12, - d_kv: int = 64, - d_ff: int = 2048, - dropout_rate: float = 0.1, - ): - super().__init__() +class T5LayerSelfAttentionCond(T5LayerSelfAttentionCond): + deprecation_message = "Importing `T5LayerSelfAttentionCond` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5LayerSelfAttentionCond`, instead." 
+ deprecate("T5LayerSelfAttentionCond", "0.29", deprecation_message) - self.conditioning_emb = nn.Sequential( - nn.Linear(d_model, d_model * 4, bias=False), - nn.SiLU(), - nn.Linear(d_model * 4, d_model * 4, bias=False), - nn.SiLU(), - ) - self.position_encoding = nn.Embedding(targets_length, d_model) - self.position_encoding.weight.requires_grad = False +class T5LayerCrossAttention(T5LayerCrossAttention): + deprecation_message = "Importing `T5LayerCrossAttention` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5LayerCrossAttention`, instead." + deprecate("T5LayerCrossAttention", "0.29", deprecation_message) - self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False) - self.dropout = nn.Dropout(p=dropout_rate) +class T5LayerFFCond(T5LayerFFCond): + deprecation_message = "Importing `T5LayerFFCond` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5LayerFFCond`, instead." + deprecate("T5LayerFFCond", "0.29", deprecation_message) - self.decoders = nn.ModuleList() - for lyr_num in range(num_layers): - # FiLM conditional T5 decoder - lyr = DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate) - self.decoders.append(lyr) - self.decoder_norm = T5LayerNorm(d_model) +class T5DenseGatedActDense(T5DenseGatedActDense): + deprecation_message = "Importing `T5DenseGatedActDense` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5DenseGatedActDense`, instead." + deprecate("T5DenseGatedActDense", "0.29", deprecation_message) - self.post_dropout = nn.Dropout(p=dropout_rate) - self.spec_out = nn.Linear(d_model, input_dims, bias=False) - def encoder_decoder_mask(self, query_input: torch.FloatTensor, key_input: torch.FloatTensor) -> torch.FloatTensor: - mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2)) - return mask.unsqueeze(-3) +class T5LayerNorm(T5LayerNorm): + deprecation_message = "Importing `T5LayerNorm` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5LayerNorm`, instead." + deprecate("T5LayerNorm", "0.29", deprecation_message) - def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time): - batch, _, _ = decoder_input_tokens.shape - assert decoder_noise_time.shape == (batch,) - # decoder_noise_time is in [0, 1), so rescale to expected timing range. - time_steps = get_timestep_embedding( - decoder_noise_time * self.config.max_decoder_noise_time, - embedding_dim=self.config.d_model, - max_period=self.config.max_decoder_noise_time, - ).to(dtype=self.dtype) +class NewGELUActivation(NewGELUActivation): + deprecation_message = "Importing `T5LayerNorm` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import NewGELUActivation`, instead." 
+ deprecate("NewGELUActivation", "0.29", deprecation_message) - conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1) - assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4) - - seq_length = decoder_input_tokens.shape[1] - - # If we want to use relative positions for audio context, we can just offset - # this sequence by the length of encodings_and_masks. - decoder_positions = torch.broadcast_to( - torch.arange(seq_length, device=decoder_input_tokens.device), - (batch, seq_length), - ) - - position_encodings = self.position_encoding(decoder_positions) - - inputs = self.continuous_inputs_projection(decoder_input_tokens) - inputs += position_encodings - y = self.dropout(inputs) - - # decoder: No padding present. - decoder_mask = torch.ones( - decoder_input_tokens.shape[:2], device=decoder_input_tokens.device, dtype=inputs.dtype - ) - - # Translate encoding masks to encoder-decoder masks. - encodings_and_encdec_masks = [(x, self.encoder_decoder_mask(decoder_mask, y)) for x, y in encodings_and_masks] - - # cross attend style: concat encodings - encoded = torch.cat([x[0] for x in encodings_and_encdec_masks], dim=1) - encoder_decoder_mask = torch.cat([x[1] for x in encodings_and_encdec_masks], dim=-1) - - for lyr in self.decoders: - y = lyr( - y, - conditioning_emb=conditioning_emb, - encoder_hidden_states=encoded, - encoder_attention_mask=encoder_decoder_mask, - )[0] - - y = self.decoder_norm(y) - y = self.post_dropout(y) - - spec_out = self.spec_out(y) - return spec_out - - -class DecoderLayer(nn.Module): - r""" - T5 decoder layer. - - Args: - d_model (`int`): - Size of the input hidden states. - d_kv (`int`): - Size of the key-value projection vectors. - num_heads (`int`): - Number of attention heads. - d_ff (`int`): - Size of the intermediate feed-forward layer. - dropout_rate (`float`): - Dropout probability. - layer_norm_epsilon (`float`, *optional*, defaults to `1e-6`): - A small value used for numerical stability to avoid dividing by zero. 
- """ - - def __init__( - self, d_model: int, d_kv: int, num_heads: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float = 1e-6 - ): - super().__init__() - self.layer = nn.ModuleList() - - # cond self attention: layer 0 - self.layer.append( - T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate) - ) - - # cross attention: layer 1 - self.layer.append( - T5LayerCrossAttention( - d_model=d_model, - d_kv=d_kv, - num_heads=num_heads, - dropout_rate=dropout_rate, - layer_norm_epsilon=layer_norm_epsilon, - ) - ) - - # Film Cond MLP + dropout: last layer - self.layer.append( - T5LayerFFCond(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon) - ) - - def forward( - self, - hidden_states: torch.FloatTensor, - conditioning_emb: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - encoder_decoder_position_bias=None, - ) -> Tuple[torch.FloatTensor]: - hidden_states = self.layer[0]( - hidden_states, - conditioning_emb=conditioning_emb, - attention_mask=attention_mask, - ) - - if encoder_hidden_states is not None: - encoder_extended_attention_mask = torch.where(encoder_attention_mask > 0, 0, -1e10).to( - encoder_hidden_states.dtype - ) - - hidden_states = self.layer[1]( - hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_extended_attention_mask, - ) - - # Apply Film Conditional Feed Forward layer - hidden_states = self.layer[-1](hidden_states, conditioning_emb) - - return (hidden_states,) - - -class T5LayerSelfAttentionCond(nn.Module): - r""" - T5 style self-attention layer with conditioning. - - Args: - d_model (`int`): - Size of the input hidden states. - d_kv (`int`): - Size of the key-value projection vectors. - num_heads (`int`): - Number of attention heads. - dropout_rate (`float`): - Dropout probability. - """ - - def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float): - super().__init__() - self.layer_norm = T5LayerNorm(d_model) - self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model) - self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False) - self.dropout = nn.Dropout(dropout_rate) - - def forward( - self, - hidden_states: torch.FloatTensor, - conditioning_emb: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - ) -> torch.FloatTensor: - # pre_self_attention_layer_norm - normed_hidden_states = self.layer_norm(hidden_states) - - if conditioning_emb is not None: - normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb) - - # Self-attention block - attention_output = self.attention(normed_hidden_states) - - hidden_states = hidden_states + self.dropout(attention_output) - - return hidden_states - - -class T5LayerCrossAttention(nn.Module): - r""" - T5 style cross-attention layer. - - Args: - d_model (`int`): - Size of the input hidden states. - d_kv (`int`): - Size of the key-value projection vectors. - num_heads (`int`): - Number of attention heads. - dropout_rate (`float`): - Dropout probability. - layer_norm_epsilon (`float`): - A small value used for numerical stability to avoid dividing by zero. 
- """ - - def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float, layer_norm_epsilon: float): - super().__init__() - self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False) - self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon) - self.dropout = nn.Dropout(dropout_rate) - - def forward( - self, - hidden_states: torch.FloatTensor, - key_value_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - ) -> torch.FloatTensor: - normed_hidden_states = self.layer_norm(hidden_states) - attention_output = self.attention( - normed_hidden_states, - encoder_hidden_states=key_value_states, - attention_mask=attention_mask.squeeze(1), - ) - layer_output = hidden_states + self.dropout(attention_output) - return layer_output - - -class T5LayerFFCond(nn.Module): - r""" - T5 style feed-forward conditional layer. - - Args: - d_model (`int`): - Size of the input hidden states. - d_ff (`int`): - Size of the intermediate feed-forward layer. - dropout_rate (`float`): - Dropout probability. - layer_norm_epsilon (`float`): - A small value used for numerical stability to avoid dividing by zero. - """ - - def __init__(self, d_model: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float): - super().__init__() - self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate) - self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model) - self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon) - self.dropout = nn.Dropout(dropout_rate) - - def forward( - self, hidden_states: torch.FloatTensor, conditioning_emb: Optional[torch.FloatTensor] = None - ) -> torch.FloatTensor: - forwarded_states = self.layer_norm(hidden_states) - if conditioning_emb is not None: - forwarded_states = self.film(forwarded_states, conditioning_emb) - - forwarded_states = self.DenseReluDense(forwarded_states) - hidden_states = hidden_states + self.dropout(forwarded_states) - return hidden_states - - -class T5DenseGatedActDense(nn.Module): - r""" - T5 style feed-forward layer with gated activations and dropout. - - Args: - d_model (`int`): - Size of the input hidden states. - d_ff (`int`): - Size of the intermediate feed-forward layer. - dropout_rate (`float`): - Dropout probability. - """ - - def __init__(self, d_model: int, d_ff: int, dropout_rate: float): - super().__init__() - self.wi_0 = nn.Linear(d_model, d_ff, bias=False) - self.wi_1 = nn.Linear(d_model, d_ff, bias=False) - self.wo = nn.Linear(d_ff, d_model, bias=False) - self.dropout = nn.Dropout(dropout_rate) - self.act = NewGELUActivation() - - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: - hidden_gelu = self.act(self.wi_0(hidden_states)) - hidden_linear = self.wi_1(hidden_states) - hidden_states = hidden_gelu * hidden_linear - hidden_states = self.dropout(hidden_states) - - hidden_states = self.wo(hidden_states) - return hidden_states - - -class T5LayerNorm(nn.Module): - r""" - T5 style layer normalization module. - - Args: - hidden_size (`int`): - Size of the input hidden states. - eps (`float`, `optional`, defaults to `1e-6`): - A small value used for numerical stability to avoid dividing by zero. - """ - - def __init__(self, hidden_size: int, eps: float = 1e-6): - """ - Construct a layernorm module in the T5 style. No bias and no subtraction of mean. 
- """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: - # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean - # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated - # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for - # half-precision inputs is done in fp32 - - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - - # convert into half-precision if necessary - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) - - return self.weight * hidden_states - - -class NewGELUActivation(nn.Module): - """ - Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see - the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 - """ - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0)))) - - -class T5FiLMLayer(nn.Module): - """ - T5 style FiLM Layer. - - Args: - in_features (`int`): - Number of input features. - out_features (`int`): - Number of output features. - """ - - def __init__(self, in_features: int, out_features: int): - super().__init__() - self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False) - - def forward(self, x: torch.FloatTensor, conditioning_emb: torch.FloatTensor) -> torch.FloatTensor: - emb = self.scale_bias(conditioning_emb) - scale, shift = torch.chunk(emb, 2, -1) - x = x * (1 + scale) + shift - return x +class T5FiLMLayer(T5FiLMLayer): + deprecation_message = "Importing `T5FiLMLayer` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5FiLMLayer`, instead." + deprecate("T5FiLMLayer", "0.29", deprecation_message) diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index 3b219b4f0b37..29f27e211bda 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -11,449 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass -from typing import Any, Dict, Optional +from ..utils import deprecate +from .transformers.transformer_2d import Transformer2DModel, Transformer2DModelOutput -import torch -import torch.nn.functional as F -from torch import nn -from ..configuration_utils import ConfigMixin, register_to_config -from ..models.embeddings import ImagePositionalEmbeddings -from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version -from .attention import BasicTransformerBlock -from .embeddings import PatchEmbed, PixArtAlphaTextProjection -from .lora import LoRACompatibleConv, LoRACompatibleLinear -from .modeling_utils import ModelMixin -from .normalization import AdaLayerNormSingle +class Transformer2DModelOutput(Transformer2DModelOutput): + deprecation_message = "Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. 
Please use `from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput`, instead." + deprecate("Transformer2DModelOutput", "0.29", deprecation_message) -@dataclass -class Transformer2DModelOutput(BaseOutput): - """ - The output of [`Transformer2DModel`]. - - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): - The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability - distributions for the unnoised latent pixels. - """ - - sample: torch.FloatTensor - - -class Transformer2DModel(ModelMixin, ConfigMixin): - """ - A 2D Transformer model for image-like data. - - Parameters: - num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. - in_channels (`int`, *optional*): - The number of channels in the input and output (specify if the input is **continuous**). - num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. - sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). - This is fixed during training since it is used to learn a number of position embeddings. - num_vector_embeds (`int`, *optional*): - The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). - Includes the class for the masked latent pixel. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. - num_embeds_ada_norm ( `int`, *optional*): - The number of diffusion steps used during training. Pass if at least one of the norm_layers is - `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are - added to the hidden states. - - During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. - attention_bias (`bool`, *optional*): - Configure if the `TransformerBlocks` attention should contain a bias parameter. 
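Because `deprecate(...)` is called in the class body of these shims, the warning is emitted as soon as the legacy module is imported, not when a class is instantiated. A sketch of the import migration:

# deprecated path, kept as a shim until the removal version given in deprecate():
from diffusers.models.transformer_2d import Transformer2DModel, Transformer2DModelOutput

# new location:
from diffusers.models.transformers.transformer_2d import Transformer2DModel, Transformer2DModelOutput

# the public top-level import is unchanged by the move:
from diffusers.models import Transformer2DModel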
- """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - num_attention_heads: int = 16, - attention_head_dim: int = 88, - in_channels: Optional[int] = None, - out_channels: Optional[int] = None, - num_layers: int = 1, - dropout: float = 0.0, - norm_num_groups: int = 32, - cross_attention_dim: Optional[int] = None, - attention_bias: bool = False, - sample_size: Optional[int] = None, - num_vector_embeds: Optional[int] = None, - patch_size: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - use_linear_projection: bool = False, - only_cross_attention: bool = False, - double_self_attention: bool = False, - upcast_attention: bool = False, - norm_type: str = "layer_norm", - norm_elementwise_affine: bool = True, - norm_eps: float = 1e-5, - attention_type: str = "default", - caption_channels: int = None, - ): - super().__init__() - self.use_linear_projection = use_linear_projection - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - inner_dim = num_attention_heads * attention_head_dim - - conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv - linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear - - # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` - # Define whether input is continuous or discrete depending on configuration - self.is_input_continuous = (in_channels is not None) and (patch_size is None) - self.is_input_vectorized = num_vector_embeds is not None - self.is_input_patches = in_channels is not None and patch_size is not None - - if norm_type == "layer_norm" and num_embeds_ada_norm is not None: - deprecation_message = ( - f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" - " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." - " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" - " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" - " would be very nice if you could open a Pull request for the `transformer/config.json` file" - ) - deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) - norm_type = "ada_norm" - - if self.is_input_continuous and self.is_input_vectorized: - raise ValueError( - f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" - " sure that either `in_channels` or `num_vector_embeds` is None." - ) - elif self.is_input_vectorized and self.is_input_patches: - raise ValueError( - f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" - " sure that either `num_vector_embeds` or `num_patches` is None." - ) - elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: - raise ValueError( - f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" - f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." - ) - - # 2. 
Define input layers - if self.is_input_continuous: - self.in_channels = in_channels - - self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) - if use_linear_projection: - self.proj_in = linear_cls(in_channels, inner_dim) - else: - self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) - elif self.is_input_vectorized: - assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" - assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" - - self.height = sample_size - self.width = sample_size - self.num_vector_embeds = num_vector_embeds - self.num_latent_pixels = self.height * self.width - - self.latent_image_embedding = ImagePositionalEmbeddings( - num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width - ) - elif self.is_input_patches: - assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" - - self.height = sample_size - self.width = sample_size - - self.patch_size = patch_size - interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1 - interpolation_scale = max(interpolation_scale, 1) - self.pos_embed = PatchEmbed( - height=sample_size, - width=sample_size, - patch_size=patch_size, - in_channels=in_channels, - embed_dim=inner_dim, - interpolation_scale=interpolation_scale, - ) - - # 3. Define transformers blocks - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - dropout=dropout, - cross_attention_dim=cross_attention_dim, - activation_fn=activation_fn, - num_embeds_ada_norm=num_embeds_ada_norm, - attention_bias=attention_bias, - only_cross_attention=only_cross_attention, - double_self_attention=double_self_attention, - upcast_attention=upcast_attention, - norm_type=norm_type, - norm_elementwise_affine=norm_elementwise_affine, - norm_eps=norm_eps, - attention_type=attention_type, - ) - for d in range(num_layers) - ] - ) - - # 4. Define output layers - self.out_channels = in_channels if out_channels is None else out_channels - if self.is_input_continuous: - # TODO: should use out_channels for continuous projections - if use_linear_projection: - self.proj_out = linear_cls(inner_dim, in_channels) - else: - self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) - elif self.is_input_vectorized: - self.norm_out = nn.LayerNorm(inner_dim) - self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) - elif self.is_input_patches and norm_type != "ada_norm_single": - self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) - self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) - self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) - elif self.is_input_patches and norm_type == "ada_norm_single": - self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) - self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) - self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) - - # 5. PixArt-Alpha blocks. 
- self.adaln_single = None - self.use_additional_conditions = False - if norm_type == "ada_norm_single": - self.use_additional_conditions = self.config.sample_size == 128 - # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use - # additional conditions until we find better name - self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions) - - self.caption_projection = None - if caption_channels is not None: - self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) - - self.gradient_checkpointing = False - - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - timestep: Optional[torch.LongTensor] = None, - added_cond_kwargs: Dict[str, torch.Tensor] = None, - class_labels: Optional[torch.LongTensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, - attention_mask: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - return_dict: bool = True, - ): - """ - The [`Transformer2DModel`] forward method. - - Args: - hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): - Input `hidden_states`. - encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): - Conditional embeddings for cross attention layer. If not given, cross-attention defaults to - self-attention. - timestep ( `torch.LongTensor`, *optional*): - Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. - class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): - Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in - `AdaLayerZeroNorm`. - cross_attention_kwargs ( `Dict[str, Any]`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - attention_mask ( `torch.Tensor`, *optional*): - An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask - is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large - negative values to the attention scores corresponding to "discard" tokens. - encoder_attention_mask ( `torch.Tensor`, *optional*): - Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: - - * Mask `(batch, sequence_length)` True = keep, False = discard. - * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. - - If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format - above. This bias will be added to the cross-attention scores. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain - tuple. - - Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. 
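A minimal forward-pass sketch for the continuous-input configuration described above; the channel and sequence sizes are illustrative choices, not values fixed by the model:

import torch
from diffusers.models.transformers.transformer_2d import Transformer2DModel

model = Transformer2DModel(
    num_attention_heads=8,
    attention_head_dim=40,   # inner_dim = 8 * 40 = 320
    in_channels=320,         # continuous input: latent feature maps
    norm_num_groups=32,
    cross_attention_dim=768,
)
latents = torch.randn(1, 320, 64, 64)
text_states = torch.randn(1, 77, 768)
out = model(latents, encoder_hidden_states=text_states)
print(out.sample.shape)      # torch.Size([1, 320, 64, 64])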
- """ - # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. - # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. - # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. - # expects mask of shape: - # [batch, key_tokens] - # adds singleton query_tokens dimension: - # [batch, 1, key_tokens] - # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: - # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) - # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) - if attention_mask is not None and attention_mask.ndim == 2: - # assume that mask is expressed as: - # (1 = keep, 0 = discard) - # convert mask into a bias that can be added to attention scores: - # (keep = +0, discard = -10000.0) - attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 - attention_mask = attention_mask.unsqueeze(1) - - # convert encoder_attention_mask to a bias the same way we do for attention_mask - if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: - encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 - encoder_attention_mask = encoder_attention_mask.unsqueeze(1) - - # Retrieve lora scale. - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 - - # 1. Input - if self.is_input_continuous: - batch, _, height, width = hidden_states.shape - residual = hidden_states - - hidden_states = self.norm(hidden_states) - if not self.use_linear_projection: - hidden_states = ( - self.proj_in(hidden_states, scale=lora_scale) - if not USE_PEFT_BACKEND - else self.proj_in(hidden_states) - ) - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - else: - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - hidden_states = ( - self.proj_in(hidden_states, scale=lora_scale) - if not USE_PEFT_BACKEND - else self.proj_in(hidden_states) - ) - - elif self.is_input_vectorized: - hidden_states = self.latent_image_embedding(hidden_states) - elif self.is_input_patches: - height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size - hidden_states = self.pos_embed(hidden_states) - - if self.adaln_single is not None: - if self.use_additional_conditions and added_cond_kwargs is None: - raise ValueError( - "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." - ) - batch_size = hidden_states.shape[0] - timestep, embedded_timestep = self.adaln_single( - timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype - ) - - # 2. 
Blocks - if self.caption_projection is not None: - batch_size = hidden_states.shape[0] - encoder_hidden_states = self.caption_projection(encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) - - for block in self.transformer_blocks: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - timestep, - cross_attention_kwargs, - class_labels, - **ckpt_kwargs, - ) - else: - hidden_states = block( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - ) - - # 3. Output - if self.is_input_continuous: - if not self.use_linear_projection: - hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - hidden_states = ( - self.proj_out(hidden_states, scale=lora_scale) - if not USE_PEFT_BACKEND - else self.proj_out(hidden_states) - ) - else: - hidden_states = ( - self.proj_out(hidden_states, scale=lora_scale) - if not USE_PEFT_BACKEND - else self.proj_out(hidden_states) - ) - hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - - output = hidden_states + residual - elif self.is_input_vectorized: - hidden_states = self.norm_out(hidden_states) - logits = self.out(hidden_states) - # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) - logits = logits.permute(0, 2, 1) - - # log(p(x_0)) - output = F.log_softmax(logits.double(), dim=1).float() - - if self.is_input_patches: - if self.config.norm_type != "ada_norm_single": - conditioning = self.transformer_blocks[0].norm1.emb( - timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) - hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] - hidden_states = self.proj_out_2(hidden_states) - elif self.config.norm_type == "ada_norm_single": - shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) - hidden_states = self.norm_out(hidden_states) - # Modulation - hidden_states = hidden_states * (1 + scale) + shift - hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.squeeze(1) - - # unpatchify - if self.adaln_single is None: - height = width = int(hidden_states.shape[1] ** 0.5) - hidden_states = hidden_states.reshape( - shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) - ) - hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) - output = hidden_states.reshape( - shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) - ) - - if not return_dict: - return (output,) - - return Transformer2DModelOutput(sample=output) +class Transformer2DModel(Transformer2DModel): + deprecation_message = "Importing `Transformer2DModel` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a 
+ future version. Please use `from diffusers.models.transformers.transformer_2d import Transformer2DModel`, instead."
+    deprecate("Transformer2DModel", "0.29", deprecation_message)
diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py
index a18671776baf..480f8faf9a27 100644
--- a/src/diffusers/models/transformer_temporal.py
+++ b/src/diffusers/models/transformer_temporal.py
@@ -11,369 +11,24 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from dataclasses import dataclass
-from typing import Any, Dict, Optional
+from ..utils import deprecate
+from .transformers.transformer_temporal import (
+    TransformerSpatioTemporalModel,
+    TransformerTemporalModel,
+    TransformerTemporalModelOutput,
+)
-import torch
-from torch import nn
-from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
-from .attention import BasicTransformerBlock, TemporalBasicTransformerBlock
-from .embeddings import TimestepEmbedding, Timesteps
-from .modeling_utils import ModelMixin
-from .resnet import AlphaBlender
+class TransformerTemporalModelOutput(TransformerTemporalModelOutput):
+    deprecation_message = "Importing `TransformerTemporalModelOutput` from `diffusers.models.transformer_temporal` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.transformer_temporal import TransformerTemporalModelOutput`, instead."
+    deprecate("TransformerTemporalModelOutput", "0.29", deprecation_message)
-@dataclass
-class TransformerTemporalModelOutput(BaseOutput):
-    """
-    The output of [`TransformerTemporalModel`].
+class TransformerTemporalModel(TransformerTemporalModel):
+    deprecation_message = "Importing `TransformerTemporalModel` from `diffusers.models.transformer_temporal` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.transformer_temporal import TransformerTemporalModel`, instead."
+    deprecate("TransformerTemporalModel", "0.29", deprecation_message)
-    Args:
-        sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
-            The hidden states output conditioned on `encoder_hidden_states` input.
-    """
-    sample: torch.FloatTensor
-
-
-class TransformerTemporalModel(ModelMixin, ConfigMixin):
-    """
-    A Transformer model for video-like data.
-
-    Parameters:
-        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
-        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
-        in_channels (`int`, *optional*):
-            The number of channels in the input and output (specify if the input is **continuous**).
-        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
-        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
-        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
-        attention_bias (`bool`, *optional*):
-            Configure if the `TransformerBlock` attention should contain a bias parameter.
-        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
-            This is fixed during training since it is used to learn a number of position embeddings.
-        activation_fn (`str`, *optional*, defaults to `"geglu"`):
-            Activation function to use in feed-forward.
See `diffusers.models.activations.get_activation` for supported - activation functions. - norm_elementwise_affine (`bool`, *optional*): - Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization. - double_self_attention (`bool`, *optional*): - Configure if each `TransformerBlock` should contain two self-attention layers. - positional_embeddings: (`str`, *optional*): - The type of positional embeddings to apply to the sequence input before passing use. - num_positional_embeddings: (`int`, *optional*): - The maximum length of the sequence over which to apply positional embeddings. - """ - - @register_to_config - def __init__( - self, - num_attention_heads: int = 16, - attention_head_dim: int = 88, - in_channels: Optional[int] = None, - out_channels: Optional[int] = None, - num_layers: int = 1, - dropout: float = 0.0, - norm_num_groups: int = 32, - cross_attention_dim: Optional[int] = None, - attention_bias: bool = False, - sample_size: Optional[int] = None, - activation_fn: str = "geglu", - norm_elementwise_affine: bool = True, - double_self_attention: bool = True, - positional_embeddings: Optional[str] = None, - num_positional_embeddings: Optional[int] = None, - ): - super().__init__() - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - inner_dim = num_attention_heads * attention_head_dim - - self.in_channels = in_channels - - self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) - self.proj_in = nn.Linear(in_channels, inner_dim) - - # 3. Define transformers blocks - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - dropout=dropout, - cross_attention_dim=cross_attention_dim, - activation_fn=activation_fn, - attention_bias=attention_bias, - double_self_attention=double_self_attention, - norm_elementwise_affine=norm_elementwise_affine, - positional_embeddings=positional_embeddings, - num_positional_embeddings=num_positional_embeddings, - ) - for d in range(num_layers) - ] - ) - - self.proj_out = nn.Linear(inner_dim, in_channels) - - def forward( - self, - hidden_states: torch.FloatTensor, - encoder_hidden_states: Optional[torch.LongTensor] = None, - timestep: Optional[torch.LongTensor] = None, - class_labels: torch.LongTensor = None, - num_frames: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = True, - ) -> TransformerTemporalModelOutput: - """ - The [`TransformerTemporal`] forward method. - - Args: - hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): - Input hidden_states. - encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): - Conditional embeddings for cross attention layer. If not given, cross-attention defaults to - self-attention. - timestep ( `torch.LongTensor`, *optional*): - Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. - class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): - Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in - `AdaLayerZeroNorm`. - num_frames (`int`, *optional*, defaults to 1): - The number of frames to be processed per batch. This is used to reshape the hidden states. 
- cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain - tuple. - - Returns: - [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: - If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is - returned, otherwise a `tuple` where the first element is the sample tensor. - """ - # 1. Input - batch_frames, channel, height, width = hidden_states.shape - batch_size = batch_frames // num_frames - - residual = hidden_states - - hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) - hidden_states = hidden_states.permute(0, 2, 1, 3, 4) - - hidden_states = self.norm(hidden_states) - hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) - - hidden_states = self.proj_in(hidden_states) - - # 2. Blocks - for block in self.transformer_blocks: - hidden_states = block( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - ) - - # 3. Output - hidden_states = self.proj_out(hidden_states) - hidden_states = ( - hidden_states[None, None, :] - .reshape(batch_size, height, width, num_frames, channel) - .permute(0, 3, 4, 1, 2) - .contiguous() - ) - hidden_states = hidden_states.reshape(batch_frames, channel, height, width) - - output = hidden_states + residual - - if not return_dict: - return (output,) - - return TransformerTemporalModelOutput(sample=output) - - -class TransformerSpatioTemporalModel(nn.Module): - """ - A Transformer model for video-like data. - - Parameters: - num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. - in_channels (`int`, *optional*): - The number of channels in the input and output (specify if the input is **continuous**). - out_channels (`int`, *optional*): - The number of channels in the output (specify if the input is **continuous**). - num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. - """ - - def __init__( - self, - num_attention_heads: int = 16, - attention_head_dim: int = 88, - in_channels: int = 320, - out_channels: Optional[int] = None, - num_layers: int = 1, - cross_attention_dim: Optional[int] = None, - ): - super().__init__() - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - - inner_dim = num_attention_heads * attention_head_dim - self.inner_dim = inner_dim - - # 2. Define input layers - self.in_channels = in_channels - self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6) - self.proj_in = nn.Linear(in_channels, inner_dim) - - # 3. 
Define transformers blocks - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - cross_attention_dim=cross_attention_dim, - ) - for d in range(num_layers) - ] - ) - - time_mix_inner_dim = inner_dim - self.temporal_transformer_blocks = nn.ModuleList( - [ - TemporalBasicTransformerBlock( - inner_dim, - time_mix_inner_dim, - num_attention_heads, - attention_head_dim, - cross_attention_dim=cross_attention_dim, - ) - for _ in range(num_layers) - ] - ) - - time_embed_dim = in_channels * 4 - self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels) - self.time_proj = Timesteps(in_channels, True, 0) - self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images") - - # 4. Define output layers - self.out_channels = in_channels if out_channels is None else out_channels - # TODO: should use out_channels for continuous projections - self.proj_out = nn.Linear(inner_dim, in_channels) - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - image_only_indicator: Optional[torch.Tensor] = None, - return_dict: bool = True, - ): - """ - Args: - hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): - Input hidden_states. - num_frames (`int`): - The number of frames to be processed per batch. This is used to reshape the hidden states. - encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): - Conditional embeddings for cross attention layer. If not given, cross-attention defaults to - self-attention. - image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*): - A tensor indicating whether the input contains only images. 1 indicates that the input contains only - images, 0 indicates that the input contains video frames. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain - tuple. - - Returns: - [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: - If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is - returned, otherwise a `tuple` where the first element is the sample tensor. - """ - # 1. 
Input
-        batch_frames, _, height, width = hidden_states.shape
-        num_frames = image_only_indicator.shape[-1]
-        batch_size = batch_frames // num_frames
-
-        time_context = encoder_hidden_states
-        time_context_first_timestep = time_context[None, :].reshape(
-            batch_size, num_frames, -1, time_context.shape[-1]
-        )[:, 0]
-        time_context = time_context_first_timestep[None, :].broadcast_to(
-            height * width, batch_size, 1, time_context.shape[-1]
-        )
-        time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1])
-
-        residual = hidden_states
-
-        hidden_states = self.norm(hidden_states)
-        inner_dim = hidden_states.shape[1]
-        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)
-        hidden_states = self.proj_in(hidden_states)
-
-        num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
-        num_frames_emb = num_frames_emb.repeat(batch_size, 1)
-        num_frames_emb = num_frames_emb.reshape(-1)
-        t_emb = self.time_proj(num_frames_emb)
-
-        # `Timesteps` does not contain any weights and will always return f32 tensors
-        # but time_embedding might actually be running in fp16. so we need to cast here.
-        # there might be better ways to encapsulate this.
-        t_emb = t_emb.to(dtype=hidden_states.dtype)
-
-        emb = self.time_pos_embed(t_emb)
-        emb = emb[:, None, :]
-
-        # 2. Blocks
-        for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    block,
-                    hidden_states,
-                    None,
-                    encoder_hidden_states,
-                    None,
-                    use_reentrant=False,
-                )
-            else:
-                hidden_states = block(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                )
-
-            hidden_states_mix = hidden_states
-            hidden_states_mix = hidden_states_mix + emb
-
-            hidden_states_mix = temporal_block(
-                hidden_states_mix,
-                num_frames=num_frames,
-                encoder_hidden_states=time_context,
-            )
-            hidden_states = self.time_mixer(
-                x_spatial=hidden_states,
-                x_temporal=hidden_states_mix,
-                image_only_indicator=image_only_indicator,
-            )
-
-        # 3. Output
-        hidden_states = self.proj_out(hidden_states)
-        hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
-
-        output = hidden_states + residual
-
-        if not return_dict:
-            return (output,)
-
-        return TransformerTemporalModelOutput(sample=output)
+class TransformerSpatioTemporalModel(TransformerSpatioTemporalModel):
+    deprecation_message = "Importing `TransformerSpatioTemporalModel` from `diffusers.models.transformer_temporal` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.transformer_temporal import TransformerSpatioTemporalModel`, instead."
+ deprecate("TransformerTemporalModelOutput", "0.29", deprecation_message) diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py new file mode 100644 index 000000000000..dc78a72b2fb8 --- /dev/null +++ b/src/diffusers/models/transformers/__init__.py @@ -0,0 +1,9 @@ +from ...utils import is_torch_available + + +if is_torch_available(): + from .dual_transformer_2d import DualTransformer2DModel + from .prior_transformer import PriorTransformer + from .t5_film_transformer import T5FilmDecoder + from .transformer_2d import Transformer2DModel + from .transformer_temporal import TransformerTemporalModel diff --git a/src/diffusers/models/transformers/dual_transformer_2d.py b/src/diffusers/models/transformers/dual_transformer_2d.py new file mode 100644 index 000000000000..21b135c2eb86 --- /dev/null +++ b/src/diffusers/models/transformers/dual_transformer_2d.py @@ -0,0 +1,155 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from torch import nn + +from .transformer_2d import Transformer2DModel, Transformer2DModelOutput + + +class DualTransformer2DModel(nn.Module): + """ + Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. 
+ """ + + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + ): + super().__init__() + self.transformers = nn.ModuleList( + [ + Transformer2DModel( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + in_channels=in_channels, + num_layers=num_layers, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_bias=attention_bias, + sample_size=sample_size, + num_vector_embeds=num_vector_embeds, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + ) + for _ in range(2) + ] + ) + + # Variables that can be set by a pipeline: + + # The ratio of transformer1 to transformer2's output states to be combined during inference + self.mix_ratio = 0.5 + + # The shape of `encoder_hidden_states` is expected to be + # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)` + self.condition_lengths = [77, 257] + + # Which transformer to use to encode which condition. + # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])` + self.transformer_index_for_condition = [1, 0] + + def forward( + self, + hidden_states, + encoder_hidden_states, + timestep=None, + attention_mask=None, + cross_attention_kwargs=None, + return_dict: bool = True, + ): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states. + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + attention_mask (`torch.FloatTensor`, *optional*): + Optional attention mask to be applied in Attention. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: + [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. 
+ """ + input_states = hidden_states + + encoded_states = [] + tokens_start = 0 + # attention_mask is not used yet + for i in range(2): + # for each of the two transformers, pass the corresponding condition tokens + condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]] + transformer_index = self.transformer_index_for_condition[i] + encoded_state = self.transformers[transformer_index]( + input_states, + encoder_hidden_states=condition_state, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + encoded_states.append(encoded_state - input_states) + tokens_start += self.condition_lengths[i] + + output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio) + output_states = output_states + input_states + + if not return_dict: + return (output_states,) + + return Transformer2DModelOutput(sample=output_states) diff --git a/src/diffusers/models/transformers/prior_transformer.py b/src/diffusers/models/transformers/prior_transformer.py new file mode 100644 index 000000000000..990eabe2c37a --- /dev/null +++ b/src/diffusers/models/transformers/prior_transformer.py @@ -0,0 +1,380 @@ +from dataclasses import dataclass +from typing import Dict, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin +from ...utils import BaseOutput +from ..attention import BasicTransformerBlock +from ..attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin + + +@dataclass +class PriorTransformerOutput(BaseOutput): + """ + The output of [`PriorTransformer`]. + + Args: + predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): + The predicted CLIP image embedding conditioned on the CLIP text embedding input. + """ + + predicted_image_embedding: torch.FloatTensor + + +class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): + """ + A Prior Transformer model. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use. + embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states` + num_embeddings (`int`, *optional*, defaults to 77): + The number of embeddings of the model input `hidden_states` + additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the + projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings + + additional_embeddings`. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + time_embed_act_fn (`str`, *optional*, defaults to 'silu'): + The activation function to use to create timestep embeddings. + norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before + passing to Transformer blocks. Set it to `None` if normalization is not needed. 
+ embedding_proj_norm_type (`str`, *optional*, defaults to None): + The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not + needed. + encoder_hid_proj_type (`str`, *optional*, defaults to `linear`): + The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if + `encoder_hidden_states` is `None`. + added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model. + Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot + product between the text embedding and image embedding as proposed in the unclip paper + https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended. + time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings. + If None, will be set to `num_attention_heads * attention_head_dim` + embedding_proj_dim (`int`, *optional*, default to None): + The dimension of `proj_embedding`. If None, will be set to `embedding_dim`. + clip_embed_dim (`int`, *optional*, default to None): + The dimension of the output. If None, will be set to `embedding_dim`. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 32, + attention_head_dim: int = 64, + num_layers: int = 20, + embedding_dim: int = 768, + num_embeddings=77, + additional_embeddings=4, + dropout: float = 0.0, + time_embed_act_fn: str = "silu", + norm_in_type: Optional[str] = None, # layer + embedding_proj_norm_type: Optional[str] = None, # layer + encoder_hid_proj_type: Optional[str] = "linear", # linear + added_emb_type: Optional[str] = "prd", # prd + time_embed_dim: Optional[int] = None, + embedding_proj_dim: Optional[int] = None, + clip_embed_dim: Optional[int] = None, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.additional_embeddings = additional_embeddings + + time_embed_dim = time_embed_dim or inner_dim + embedding_proj_dim = embedding_proj_dim or embedding_dim + clip_embed_dim = clip_embed_dim or embedding_dim + + self.time_proj = Timesteps(inner_dim, True, 0) + self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn) + + self.proj_in = nn.Linear(embedding_dim, inner_dim) + + if embedding_proj_norm_type is None: + self.embedding_proj_norm = None + elif embedding_proj_norm_type == "layer": + self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim) + else: + raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}") + + self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim) + + if encoder_hid_proj_type is None: + self.encoder_hidden_states_proj = None + elif encoder_hid_proj_type == "linear": + self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim) + else: + raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}") + + self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim)) + + if added_emb_type == "prd": + self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim)) + elif added_emb_type is None: + self.prd_embedding = None + else: + raise ValueError( + f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`." 
+ ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + activation_fn="gelu", + attention_bias=True, + ) + for d in range(num_layers) + ] + ) + + if norm_in_type == "layer": + self.norm_in = nn.LayerNorm(inner_dim) + elif norm_in_type is None: + self.norm_in = None + else: + raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.") + + self.norm_out = nn.LayerNorm(inner_dim) + + self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim) + + causal_attention_mask = torch.full( + [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0 + ) + causal_attention_mask.triu_(1) + causal_attention_mask = causal_attention_mask[None, ...] + self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False) + + self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim)) + self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim)) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
+ ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def forward( + self, + hidden_states, + timestep: Union[torch.Tensor, float, int], + proj_embedding: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + return_dict: bool = True, + ): + """ + The [`PriorTransformer`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): + The currently predicted image embeddings. + timestep (`torch.LongTensor`): + Current denoising step. + proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): + Projected embedding vector the denoising process is conditioned on. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`): + Hidden states of the text embeddings the denoising process is conditioned on. + attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`): + Text mask for the text embeddings. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain + tuple. + + Returns: + [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`: + If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + """ + batch_size = hidden_states.shape[0] + + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(hidden_states.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device) + + timesteps_projected = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might be fp16, so we need to cast here. 
+ timesteps_projected = timesteps_projected.to(dtype=self.dtype) + time_embeddings = self.time_embedding(timesteps_projected) + + if self.embedding_proj_norm is not None: + proj_embedding = self.embedding_proj_norm(proj_embedding) + + proj_embeddings = self.embedding_proj(proj_embedding) + if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None: + encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states) + elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None: + raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set") + + hidden_states = self.proj_in(hidden_states) + + positional_embeddings = self.positional_embedding.to(hidden_states.dtype) + + additional_embeds = [] + additional_embeddings_len = 0 + + if encoder_hidden_states is not None: + additional_embeds.append(encoder_hidden_states) + additional_embeddings_len += encoder_hidden_states.shape[1] + + if len(proj_embeddings.shape) == 2: + proj_embeddings = proj_embeddings[:, None, :] + + if len(hidden_states.shape) == 2: + hidden_states = hidden_states[:, None, :] + + additional_embeds = additional_embeds + [ + proj_embeddings, + time_embeddings[:, None, :], + hidden_states, + ] + + if self.prd_embedding is not None: + prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1) + additional_embeds.append(prd_embedding) + + hidden_states = torch.cat( + additional_embeds, + dim=1, + ) + + # Allow positional_embedding to not include the `addtional_embeddings` and instead pad it with zeros for these additional tokens + additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1 + if positional_embeddings.shape[1] < hidden_states.shape[1]: + positional_embeddings = F.pad( + positional_embeddings, + ( + 0, + 0, + additional_embeddings_len, + self.prd_embedding.shape[1] if self.prd_embedding is not None else 0, + ), + value=0.0, + ) + + hidden_states = hidden_states + positional_embeddings + + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0) + attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype) + attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0) + + if self.norm_in is not None: + hidden_states = self.norm_in(hidden_states) + + for block in self.transformer_blocks: + hidden_states = block(hidden_states, attention_mask=attention_mask) + + hidden_states = self.norm_out(hidden_states) + + if self.prd_embedding is not None: + hidden_states = hidden_states[:, -1] + else: + hidden_states = hidden_states[:, additional_embeddings_len:] + + predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states) + + if not return_dict: + return (predicted_image_embedding,) + + return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding) + + def post_process_latents(self, prior_latents): + prior_latents = (prior_latents * self.clip_std) + self.clip_mean + return prior_latents diff --git a/src/diffusers/models/transformers/t5_film_transformer.py b/src/diffusers/models/transformers/t5_film_transformer.py new file mode 100644 index 000000000000..b2d735f76a59 --- /dev/null +++ b/src/diffusers/models/transformers/t5_film_transformer.py @@ -0,0 +1,438 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Optional, Tuple + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ..attention_processor import Attention +from ..embeddings import get_timestep_embedding +from ..modeling_utils import ModelMixin + + +class T5FilmDecoder(ModelMixin, ConfigMixin): + r""" + T5 style decoder with FiLM conditioning. + + Args: + input_dims (`int`, *optional*, defaults to `128`): + The number of input dimensions. + targets_length (`int`, *optional*, defaults to `256`): + The length of the targets. + d_model (`int`, *optional*, defaults to `768`): + Size of the input hidden states. + num_layers (`int`, *optional*, defaults to `12`): + The number of `DecoderLayer`'s to use. + num_heads (`int`, *optional*, defaults to `12`): + The number of attention heads to use. + d_kv (`int`, *optional*, defaults to `64`): + Size of the key-value projection vectors. + d_ff (`int`, *optional*, defaults to `2048`): + The number of dimensions in the intermediate feed-forward layer of `DecoderLayer`'s. + dropout_rate (`float`, *optional*, defaults to `0.1`): + Dropout probability. + """ + + @register_to_config + def __init__( + self, + input_dims: int = 128, + targets_length: int = 256, + max_decoder_noise_time: float = 2000.0, + d_model: int = 768, + num_layers: int = 12, + num_heads: int = 12, + d_kv: int = 64, + d_ff: int = 2048, + dropout_rate: float = 0.1, + ): + super().__init__() + + self.conditioning_emb = nn.Sequential( + nn.Linear(d_model, d_model * 4, bias=False), + nn.SiLU(), + nn.Linear(d_model * 4, d_model * 4, bias=False), + nn.SiLU(), + ) + + self.position_encoding = nn.Embedding(targets_length, d_model) + self.position_encoding.weight.requires_grad = False + + self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False) + + self.dropout = nn.Dropout(p=dropout_rate) + + self.decoders = nn.ModuleList() + for lyr_num in range(num_layers): + # FiLM conditional T5 decoder + lyr = DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate) + self.decoders.append(lyr) + + self.decoder_norm = T5LayerNorm(d_model) + + self.post_dropout = nn.Dropout(p=dropout_rate) + self.spec_out = nn.Linear(d_model, input_dims, bias=False) + + def encoder_decoder_mask(self, query_input: torch.FloatTensor, key_input: torch.FloatTensor) -> torch.FloatTensor: + mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2)) + return mask.unsqueeze(-3) + + def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time): + batch, _, _ = decoder_input_tokens.shape + assert decoder_noise_time.shape == (batch,) + + # decoder_noise_time is in [0, 1), so rescale to expected timing range. 
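+        # e.g. with the default max_decoder_noise_time of 2000.0, decoder_noise_time=0.5 maps to timestep 1000.0 for the sinusoidal embedding below.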
+ time_steps = get_timestep_embedding( + decoder_noise_time * self.config.max_decoder_noise_time, + embedding_dim=self.config.d_model, + max_period=self.config.max_decoder_noise_time, + ).to(dtype=self.dtype) + + conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1) + + assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4) + + seq_length = decoder_input_tokens.shape[1] + + # If we want to use relative positions for audio context, we can just offset + # this sequence by the length of encodings_and_masks. + decoder_positions = torch.broadcast_to( + torch.arange(seq_length, device=decoder_input_tokens.device), + (batch, seq_length), + ) + + position_encodings = self.position_encoding(decoder_positions) + + inputs = self.continuous_inputs_projection(decoder_input_tokens) + inputs += position_encodings + y = self.dropout(inputs) + + # decoder: No padding present. + decoder_mask = torch.ones( + decoder_input_tokens.shape[:2], device=decoder_input_tokens.device, dtype=inputs.dtype + ) + + # Translate encoding masks to encoder-decoder masks. + encodings_and_encdec_masks = [(x, self.encoder_decoder_mask(decoder_mask, y)) for x, y in encodings_and_masks] + + # cross attend style: concat encodings + encoded = torch.cat([x[0] for x in encodings_and_encdec_masks], dim=1) + encoder_decoder_mask = torch.cat([x[1] for x in encodings_and_encdec_masks], dim=-1) + + for lyr in self.decoders: + y = lyr( + y, + conditioning_emb=conditioning_emb, + encoder_hidden_states=encoded, + encoder_attention_mask=encoder_decoder_mask, + )[0] + + y = self.decoder_norm(y) + y = self.post_dropout(y) + + spec_out = self.spec_out(y) + return spec_out + + +class DecoderLayer(nn.Module): + r""" + T5 decoder layer. + + Args: + d_model (`int`): + Size of the input hidden states. + d_kv (`int`): + Size of the key-value projection vectors. + num_heads (`int`): + Number of attention heads. + d_ff (`int`): + Size of the intermediate feed-forward layer. + dropout_rate (`float`): + Dropout probability. + layer_norm_epsilon (`float`, *optional*, defaults to `1e-6`): + A small value used for numerical stability to avoid dividing by zero. 
+ """ + + def __init__( + self, d_model: int, d_kv: int, num_heads: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float = 1e-6 + ): + super().__init__() + self.layer = nn.ModuleList() + + # cond self attention: layer 0 + self.layer.append( + T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate) + ) + + # cross attention: layer 1 + self.layer.append( + T5LayerCrossAttention( + d_model=d_model, + d_kv=d_kv, + num_heads=num_heads, + dropout_rate=dropout_rate, + layer_norm_epsilon=layer_norm_epsilon, + ) + ) + + # Film Cond MLP + dropout: last layer + self.layer.append( + T5LayerFFCond(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon) + ) + + def forward( + self, + hidden_states: torch.FloatTensor, + conditioning_emb: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + encoder_decoder_position_bias=None, + ) -> Tuple[torch.FloatTensor]: + hidden_states = self.layer[0]( + hidden_states, + conditioning_emb=conditioning_emb, + attention_mask=attention_mask, + ) + + if encoder_hidden_states is not None: + encoder_extended_attention_mask = torch.where(encoder_attention_mask > 0, 0, -1e10).to( + encoder_hidden_states.dtype + ) + + hidden_states = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_extended_attention_mask, + ) + + # Apply Film Conditional Feed Forward layer + hidden_states = self.layer[-1](hidden_states, conditioning_emb) + + return (hidden_states,) + + +class T5LayerSelfAttentionCond(nn.Module): + r""" + T5 style self-attention layer with conditioning. + + Args: + d_model (`int`): + Size of the input hidden states. + d_kv (`int`): + Size of the key-value projection vectors. + num_heads (`int`): + Number of attention heads. + dropout_rate (`float`): + Dropout probability. + """ + + def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float): + super().__init__() + self.layer_norm = T5LayerNorm(d_model) + self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model) + self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False) + self.dropout = nn.Dropout(dropout_rate) + + def forward( + self, + hidden_states: torch.FloatTensor, + conditioning_emb: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + # pre_self_attention_layer_norm + normed_hidden_states = self.layer_norm(hidden_states) + + if conditioning_emb is not None: + normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb) + + # Self-attention block + attention_output = self.attention(normed_hidden_states) + + hidden_states = hidden_states + self.dropout(attention_output) + + return hidden_states + + +class T5LayerCrossAttention(nn.Module): + r""" + T5 style cross-attention layer. + + Args: + d_model (`int`): + Size of the input hidden states. + d_kv (`int`): + Size of the key-value projection vectors. + num_heads (`int`): + Number of attention heads. + dropout_rate (`float`): + Dropout probability. + layer_norm_epsilon (`float`): + A small value used for numerical stability to avoid dividing by zero. 
+ """ + + def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float, layer_norm_epsilon: float): + super().__init__() + self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False) + self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon) + self.dropout = nn.Dropout(dropout_rate) + + def forward( + self, + hidden_states: torch.FloatTensor, + key_value_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.attention( + normed_hidden_states, + encoder_hidden_states=key_value_states, + attention_mask=attention_mask.squeeze(1), + ) + layer_output = hidden_states + self.dropout(attention_output) + return layer_output + + +class T5LayerFFCond(nn.Module): + r""" + T5 style feed-forward conditional layer. + + Args: + d_model (`int`): + Size of the input hidden states. + d_ff (`int`): + Size of the intermediate feed-forward layer. + dropout_rate (`float`): + Dropout probability. + layer_norm_epsilon (`float`): + A small value used for numerical stability to avoid dividing by zero. + """ + + def __init__(self, d_model: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float): + super().__init__() + self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate) + self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model) + self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon) + self.dropout = nn.Dropout(dropout_rate) + + def forward( + self, hidden_states: torch.FloatTensor, conditioning_emb: Optional[torch.FloatTensor] = None + ) -> torch.FloatTensor: + forwarded_states = self.layer_norm(hidden_states) + if conditioning_emb is not None: + forwarded_states = self.film(forwarded_states, conditioning_emb) + + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +class T5DenseGatedActDense(nn.Module): + r""" + T5 style feed-forward layer with gated activations and dropout. + + Args: + d_model (`int`): + Size of the input hidden states. + d_ff (`int`): + Size of the intermediate feed-forward layer. + dropout_rate (`float`): + Dropout probability. + """ + + def __init__(self, d_model: int, d_ff: int, dropout_rate: float): + super().__init__() + self.wi_0 = nn.Linear(d_model, d_ff, bias=False) + self.wi_1 = nn.Linear(d_model, d_ff, bias=False) + self.wo = nn.Linear(d_ff, d_model, bias=False) + self.dropout = nn.Dropout(dropout_rate) + self.act = NewGELUActivation() + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5LayerNorm(nn.Module): + r""" + T5 style layer normalization module. + + Args: + hidden_size (`int`): + Size of the input hidden states. + eps (`float`, `optional`, defaults to `1e-6`): + A small value used for numerical stability to avoid dividing by zero. + """ + + def __init__(self, hidden_size: int, eps: float = 1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. 
+ """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +class NewGELUActivation(nn.Module): + """ + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see + the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 + """ + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0)))) + + +class T5FiLMLayer(nn.Module): + """ + T5 style FiLM Layer. + + Args: + in_features (`int`): + Number of input features. + out_features (`int`): + Number of output features. + """ + + def __init__(self, in_features: int, out_features: int): + super().__init__() + self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False) + + def forward(self, x: torch.FloatTensor, conditioning_emb: torch.FloatTensor) -> torch.FloatTensor: + emb = self.scale_bias(conditioning_emb) + scale, shift = torch.chunk(emb, 2, -1) + x = x * (1 + scale) + shift + return x diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py new file mode 100644 index 000000000000..17bf8386c106 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_2d.py @@ -0,0 +1,458 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version +from ..attention import BasicTransformerBlock +from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection +from ..lora import LoRACompatibleConv, LoRACompatibleLinear +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormSingle + + +@dataclass +class Transformer2DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. 
+ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +class Transformer2DModel(ModelMixin, ConfigMixin): + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + caption_channels: int = None, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + + # 1. 
Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + # 2. 
Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = linear_cls(in_channels, inner_dim) + else: + self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + elif self.is_input_patches: + assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1 + interpolation_scale = max(interpolation_scale, 1) + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + interpolation_scale=interpolation_scale, + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=double_self_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continuous projections + if use_linear_projection: + self.proj_out = linear_cls(inner_dim, in_channels) + else: + self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches and norm_type != "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + elif self.is_input_patches and norm_type == "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + # 5. PixArt-Alpha blocks. 
+ self.adaln_single = None + self.use_additional_conditions = False + if norm_type == "ada_norm_single": + self.use_additional_conditions = self.config.sample_size == 128 + # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use + # additional conditions until we find better name + self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions) + + self.caption_projection = None + if caption_channels is not None: + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. 
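To make the continuous-input path concrete, here is a minimal usage sketch; the configuration values and tensor sizes are illustrative and not taken from this patch. With `in_channels` set and `patch_size`/`num_vector_embeds` left as `None`, the model consumes a `(batch, channels, height, width)` latent plus optional cross-attention states and returns a sample of the same shape.

import torch
from diffusers.models.transformers.transformer_2d import Transformer2DModel

# Small illustrative configuration for the continuous-input mode.
model = Transformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=4,
    num_layers=1,
    cross_attention_dim=32,
    norm_num_groups=2,
)

latents = torch.randn(1, 4, 16, 16)             # (batch, channels, height, width)
encoder_hidden_states = torch.randn(1, 77, 32)  # (batch, sequence_length, cross_attention_dim)

sample = model(latents, encoder_hidden_states=encoder_hidden_states).sample
print(sample.shape)  # torch.Size([1, 4, 16, 16])

The vectorized and patched input modes follow the same call pattern but typically require `timestep`, `class_labels`, or caption embeddings in addition, depending on the configured `norm_type`.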
+ """ + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 1. Input + if self.is_input_continuous: + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = ( + self.proj_in(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_in(hidden_states) + ) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = ( + self.proj_in(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_in(hidden_states) + ) + + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size + hidden_states = self.pos_embed(hidden_states) + + if self.adaln_single is not None: + if self.use_additional_conditions and added_cond_kwargs is None: + raise ValueError( + "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." + ) + batch_size = hidden_states.shape[0] + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + # 2. 
Blocks + if self.caption_projection is not None: + batch_size = hidden_states.shape[0] + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + + for block in self.transformer_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + class_labels, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + if self.is_input_continuous: + if not self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = ( + self.proj_out(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_out(hidden_states) + ) + else: + hidden_states = ( + self.proj_out(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_out(hidden_states) + ) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + elif self.is_input_vectorized: + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + + if self.is_input_patches: + if self.config.norm_type != "ada_norm_single": + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + elif self.config.norm_type == "ada_norm_single": + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.squeeze(1) + + # unpatchify + if self.adaln_single is None: + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/transformer_temporal.py b/src/diffusers/models/transformers/transformer_temporal.py new file mode 100644 index 000000000000..e5bc1226b4b5 
--- /dev/null +++ b/src/diffusers/models/transformers/transformer_temporal.py @@ -0,0 +1,379 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ..attention import BasicTransformerBlock, TemporalBasicTransformerBlock +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from ..resnet import AlphaBlender + + +@dataclass +class TransformerTemporalModelOutput(BaseOutput): + """ + The output of [`TransformerTemporalModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. + """ + + sample: torch.FloatTensor + + +class TransformerTemporalModel(ModelMixin, ConfigMixin): + """ + A Transformer model for video-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlock` attention should contain a bias parameter. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported + activation functions. + norm_elementwise_affine (`bool`, *optional*): + Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization. + double_self_attention (`bool`, *optional*): + Configure if each `TransformerBlock` should contain two self-attention layers. + positional_embeddings: (`str`, *optional*): + The type of positional embeddings to apply to the sequence input before passing use. + num_positional_embeddings: (`int`, *optional*): + The maximum length of the sequence over which to apply positional embeddings. 
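A minimal usage sketch may make the frame handling clearer (sizes are illustrative, not part of the patch): video latents arrive with the frame axis folded into the batch dimension, and `num_frames` tells the model how to unfold them so that attention runs along the temporal axis.

import torch
from diffusers.models.transformers.transformer_temporal import TransformerTemporalModel

model = TransformerTemporalModel(
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=32,
    num_layers=1,
    norm_num_groups=8,
)

num_frames = 4
# (batch * num_frames, channels, height, width): 2 videos of 4 frames each.
sample = torch.randn(2 * num_frames, 32, 8, 8)

output = model(sample, num_frames=num_frames).sample
print(output.shape)  # torch.Size([8, 32, 8, 8])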
+ """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + activation_fn: str = "geglu", + norm_elementwise_affine: bool = True, + double_self_attention: bool = True, + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + double_self_attention=double_self_attention, + norm_elementwise_affine=norm_elementwise_affine, + positional_embeddings=positional_embeddings, + num_positional_embeddings=num_positional_embeddings, + ) + for d in range(num_layers) + ] + ) + + self.proj_out = nn.Linear(inner_dim, in_channels) + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.LongTensor] = None, + timestep: Optional[torch.LongTensor] = None, + class_labels: torch.LongTensor = None, + num_frames: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> TransformerTemporalModelOutput: + """ + The [`TransformerTemporal`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input hidden_states. + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + num_frames (`int`, *optional*, defaults to 1): + The number of frames to be processed per batch. This is used to reshape the hidden states. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. 
+ + Returns: + [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: + If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is + returned, otherwise a `tuple` where the first element is the sample tensor. + """ + # 1. Input + batch_frames, channel, height, width = hidden_states.shape + batch_size = batch_frames // num_frames + + residual = hidden_states + + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) + + hidden_states = self.proj_in(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states[None, None, :] + .reshape(batch_size, height, width, num_frames, channel) + .permute(0, 3, 4, 1, 2) + .contiguous() + ) + hidden_states = hidden_states.reshape(batch_frames, channel, height, width) + + output = hidden_states + residual + + if not return_dict: + return (output,) + + return TransformerTemporalModelOutput(sample=output) + + +class TransformerSpatioTemporalModel(nn.Module): + """ + A Transformer model for video-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + out_channels (`int`, *optional*): + The number of channels in the output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + """ + + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: int = 320, + out_channels: Optional[int] = None, + num_layers: int = 1, + cross_attention_dim: Optional[int] = None, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + + inner_dim = num_attention_heads * attention_head_dim + self.inner_dim = inner_dim + + # 2. Define input layers + self.in_channels = in_channels + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6) + self.proj_in = nn.Linear(in_channels, inner_dim) + + # 3. 
Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + ) + for d in range(num_layers) + ] + ) + + time_mix_inner_dim = inner_dim + self.temporal_transformer_blocks = nn.ModuleList( + [ + TemporalBasicTransformerBlock( + inner_dim, + time_mix_inner_dim, + num_attention_heads, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + ) + for _ in range(num_layers) + ] + ) + + time_embed_dim = in_channels * 4 + self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels) + self.time_proj = Timesteps(in_channels, True, 0) + self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images") + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + # TODO: should use out_channels for continuous projections + self.proj_out = nn.Linear(inner_dim, in_channels) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + image_only_indicator: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input hidden_states. + num_frames (`int`): + The number of frames to be processed per batch. This is used to reshape the hidden states. + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*): + A tensor indicating whether the input contains only images. 1 indicates that the input contains only + images, 0 indicates that the input contains video frames. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain + tuple. + + Returns: + [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: + If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is + returned, otherwise a `tuple` where the first element is the sample tensor. + """ + # 1. 
Input + batch_frames, _, height, width = hidden_states.shape + num_frames = image_only_indicator.shape[-1] + batch_size = batch_frames // num_frames + + time_context = encoder_hidden_states + time_context_first_timestep = time_context[None, :].reshape( + batch_size, num_frames, -1, time_context.shape[-1] + )[:, 0] + time_context = time_context_first_timestep[None, :].broadcast_to( + height * width, batch_size, 1, time_context.shape[-1] + ) + time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1]) + + residual = hidden_states + + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states) + + num_frames_emb = torch.arange(num_frames, device=hidden_states.device) + num_frames_emb = num_frames_emb.repeat(batch_size, 1) + num_frames_emb = num_frames_emb.reshape(-1) + t_emb = self.time_proj(num_frames_emb) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=hidden_states.dtype) + + emb = self.time_pos_embed(t_emb) + emb = emb[:, None, :] + + # 2. Blocks + for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks): + if self.training and self.gradient_checkpointing: + hidden_states = torch.utils.checkpoint.checkpoint( + block, + hidden_states, + None, + encoder_hidden_states, + None, + use_reentrant=False, + ) + else: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + ) + + hidden_states_mix = hidden_states + hidden_states_mix = hidden_states_mix + emb + + hidden_states_mix = temporal_block( + hidden_states_mix, + num_frames=num_frames, + encoder_hidden_states=time_context, + ) + hidden_states = self.time_mixer( + x_spatial=hidden_states, + x_temporal=hidden_states_mix, + image_only_indicator=image_only_indicator, + ) + + # 3. 
Output + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + + if not return_dict: + return (output,) + + return TransformerTemporalModelOutput(sample=output) diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index d933691d89d3..3796896ef675 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -22,7 +22,6 @@ from ...utils.torch_utils import apply_freeu from ..activations import get_activation from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 -from ..dual_transformer_2d import DualTransformer2DModel from ..normalization import AdaGroupNorm from ..resnet import ( Downsample2D, @@ -34,7 +33,8 @@ ResnetBlockCondNorm2D, Upsample2D, ) -from ..transformer_2d import Transformer2DModel +from ..transformers.dual_transformer_2d import DualTransformer2DModel +from ..transformers.transformer_2d import Transformer2DModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/unets/unet_3d_blocks.py b/src/diffusers/models/unets/unet_3d_blocks.py index 6c20b1175349..a1d9e848c230 100644 --- a/src/diffusers/models/unets/unet_3d_blocks.py +++ b/src/diffusers/models/unets/unet_3d_blocks.py @@ -20,7 +20,6 @@ from ...utils import is_torch_version from ...utils.torch_utils import apply_freeu from ..attention import Attention -from ..dual_transformer_2d import DualTransformer2DModel from ..resnet import ( Downsample2D, ResnetBlock2D, @@ -28,8 +27,9 @@ TemporalConvLayer, Upsample2D, ) -from ..transformer_2d import Transformer2DModel -from ..transformer_temporal import ( +from ..transformers.dual_transformer_2d import DualTransformer2DModel +from ..transformers.transformer_2d import Transformer2DModel +from ..transformers.transformer_temporal import ( TransformerSpatioTemporalModel, TransformerTemporalModel, ) diff --git a/src/diffusers/models/unets/unet_3d_condition.py b/src/diffusers/models/unets/unet_3d_condition.py index b29e2c270ba9..c28fd58222da 100644 --- a/src/diffusers/models/unets/unet_3d_condition.py +++ b/src/diffusers/models/unets/unet_3d_condition.py @@ -33,7 +33,7 @@ ) from ..embeddings import TimestepEmbedding, Timesteps from ..modeling_utils import ModelMixin -from ..transformer_temporal import TransformerTemporalModel +from ..transformers.transformer_temporal import TransformerTemporalModel from .unet_3d_blocks import ( CrossAttnDownBlock3D, CrossAttnUpBlock3D, diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 9654ae508215..9166cda82369 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -29,7 +29,7 @@ ) from ..embeddings import TimestepEmbedding, Timesteps from ..modeling_utils import ModelMixin -from ..transformer_temporal import TransformerTemporalModel +from ..transformers.transformer_temporal import TransformerTemporalModel from .unet_2d_blocks import UNetMidBlock2DCrossAttn from .unet_2d_condition import UNet2DConditionModel from .unet_3d_blocks import ( diff --git a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py index 147dd7a58e7b..b8a37696aa6c 100644 --- a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py +++ 
b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py @@ -35,7 +35,7 @@ ) from ...models.modeling_utils import ModelMixin from ...models.resnet import Downsample2D, ResnetBlock2D, Upsample2D -from ...models.transformer_2d import Transformer2DModel +from ...models.transformers.transformer_2d import Transformer2DModel from ...models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D from ...models.unets.unet_2d_condition import UNet2DConditionOutput from ...utils import BaseOutput, is_torch_version, logging diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 20884a15da4d..e772d8be2a45 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -19,7 +19,6 @@ AttnAddedKVProcessor2_0, AttnProcessor, ) -from ....models.dual_transformer_2d import DualTransformer2DModel from ....models.embeddings import ( GaussianFourierProjection, ImageHintTimeEmbedding, @@ -32,7 +31,8 @@ Timesteps, ) from ....models.resnet import ResnetBlockCondNorm2D -from ....models.transformer_2d import Transformer2DModel +from ....models.transformers.dual_transformer_2d import DualTransformer2DModel +from ....models.transformers.transformer_2d import Transformer2DModel from ....models.unets.unet_2d_condition import UNet2DConditionOutput from ....utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ....utils.torch_utils import apply_freeu diff --git a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py index 561d8344e746..c074b9916301 100644 --- a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py +++ b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py @@ -10,7 +10,7 @@ from ...models.attention_processor import Attention from ...models.embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed from ...models.normalization import AdaLayerNorm -from ...models.transformer_2d import Transformer2DModelOutput +from ...models.transformers.transformer_2d import Transformer2DModelOutput from ...utils import logging diff --git a/tests/models/test_layers_utils.py b/tests/models/test_layers_utils.py index c6e3e19d4cc3..bc4c4a606f95 100644 --- a/tests/models/test_layers_utils.py +++ b/tests/models/test_layers_utils.py @@ -24,7 +24,7 @@ from diffusers.models.embeddings import get_timestep_embedding from diffusers.models.lora import LoRACompatibleLinear from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D -from diffusers.models.transformer_2d import Transformer2DModel +from diffusers.models.transformers.transformer_2d import Transformer2DModel from diffusers.utils.testing_utils import ( backend_manual_seed, require_torch_accelerator_with_fp64, diff --git a/tests/models/transformers/__init__.py b/tests/models/transformers/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/test_models_prior.py b/tests/models/transformers/test_models_prior.py similarity index 99% rename from tests/models/test_models_prior.py rename to tests/models/transformers/test_models_prior.py index 896a75de6f1b..c014220d8006 100644 --- a/tests/models/test_models_prior.py +++ b/tests/models/transformers/test_models_prior.py @@ -30,7 +30,7 @@ torch_device, ) -from .test_modeling_common import ModelTesterMixin +from ..test_modeling_common import 
ModelTesterMixin enable_full_determinism()
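Taken together, the hunks above relocate the transformer model definitions into a `diffusers.models.transformers` subpackage and update the in-repo call sites accordingly. The sketch below summarizes the resulting import paths, mirroring the updated pipelines, UNet blocks, and tests in this diff; treat it as a summary, not additional patch content.

# Old flat-module paths (see the removed import lines above):
#   from diffusers.models.transformer_2d import Transformer2DModel
#   from diffusers.models.transformer_temporal import TransformerTemporalModel
# New locations under the `transformers` subpackage:
from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel
from diffusers.models.transformers.transformer_2d import Transformer2DModel, Transformer2DModelOutput
from diffusers.models.transformers.transformer_temporal import TransformerTemporalModel

print(Transformer2DModel.__module__)  # diffusers.models.transformers.transformer_2d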