diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 3fb4202ed119..9243dc66d3e8 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -16,6 +16,7 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import torch.utils.checkpoint

 from ..configuration_utils import ConfigMixin, register_to_config
@@ -101,6 +102,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             class conditioning with `class_embed_type` equal to `None`.
         time_embedding_type (`str`, *optional*, default to `positional`):
             The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+        time_embedding_act_fn (`str`, *optional*, default to `None`):
+            Optional activation function to use on the time embeddings, applied only once before they are passed to
+            the rest of the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
         timestep_post_act (`str`, *optional*, default to `None`):
             The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
         time_cond_proj_dim (`int`, *optional*, default to `None`):
@@ -157,6 +161,7 @@ def __init__(
         resnet_skip_time_act: bool = False,
         resnet_out_scale_factor: int = 1.0,
         time_embedding_type: str = "positional",
+        time_embedding_act_fn: Optional[str] = None,
         timestep_post_act: Optional[str] = None,
         time_cond_proj_dim: Optional[int] = None,
         conv_in_kernel: int = 3,
@@ -267,6 +272,19 @@ def __init__(
         else:
             self.class_embedding = None

+        if time_embedding_act_fn is None:
+            self.time_embed_act = None
+        elif time_embedding_act_fn == "swish":
+            self.time_embed_act = lambda x: F.silu(x)
+        elif time_embedding_act_fn == "mish":
+            self.time_embed_act = nn.Mish()
+        elif time_embedding_act_fn == "silu":
+            self.time_embed_act = nn.SiLU()
+        elif time_embedding_act_fn == "gelu":
+            self.time_embed_act = nn.GELU()
+        else:
+            raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
+
         self.down_blocks = nn.ModuleList([])
         self.up_blocks = nn.ModuleList([])
@@ -657,6 +675,9 @@ def forward(
         else:
             emb = emb + class_emb

+        if self.time_embed_act is not None:
+            emb = self.time_embed_act(emb)
+
         if self.encoder_hid_proj is not None:
             encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index 51d1c62c926b..cc8cde4daa3b 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -3,6 +3,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models import ModelMixin
@@ -182,6 +183,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             class conditioning with `class_embed_type` equal to `None`.
         time_embedding_type (`str`, *optional*, default to `positional`):
             The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+        time_embedding_act_fn (`str`, *optional*, default to `None`):
+            Optional activation function to use on the time embeddings, applied only once before they are passed to
+            the rest of the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
         timestep_post_act (`str`, *optional*, default to `None`):
             The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
         time_cond_proj_dim (`int`, *optional*, default to `None`):
@@ -243,6 +247,7 @@ def __init__(
         resnet_skip_time_act: bool = False,
         resnet_out_scale_factor: int = 1.0,
         time_embedding_type: str = "positional",
+        time_embedding_act_fn: Optional[str] = None,
         timestep_post_act: Optional[str] = None,
         time_cond_proj_dim: Optional[int] = None,
         conv_in_kernel: int = 3,
@@ -359,6 +364,19 @@ def __init__(
         else:
             self.class_embedding = None

+        if time_embedding_act_fn is None:
+            self.time_embed_act = None
+        elif time_embedding_act_fn == "swish":
+            self.time_embed_act = lambda x: F.silu(x)
+        elif time_embedding_act_fn == "mish":
+            self.time_embed_act = nn.Mish()
+        elif time_embedding_act_fn == "silu":
+            self.time_embed_act = nn.SiLU()
+        elif time_embedding_act_fn == "gelu":
+            self.time_embed_act = nn.GELU()
+        else:
+            raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
+
         self.down_blocks = nn.ModuleList([])
         self.up_blocks = nn.ModuleList([])
@@ -752,6 +770,9 @@ def forward(
         else:
             emb = emb + class_emb

+        if self.time_embed_act is not None:
+            emb = self.time_embed_act(emb)
+
         if self.encoder_hid_proj is not None:
             encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
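Below is a quick usage sketch (not part of the diff) showing how the new `time_embedding_act_fn` flag would be passed so the extra activation is applied once to the time embedding in `forward()`. The block types, channel counts, and tensor shapes are deliberately tiny, test-style assumptions, not the Stable Diffusion defaults.

```python
import torch
from diffusers import UNet2DConditionModel

# Tiny illustrative config; the sizes here are assumptions for the sketch.
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    block_out_channels=(32, 64),
    cross_attention_dim=32,
    time_embedding_act_fn="gelu",  # new flag from this diff; selects nn.GELU()
)

sample = torch.randn(1, 4, 32, 32)              # noisy latents
timestep = torch.tensor([10])                   # diffusion timestep
encoder_hidden_states = torch.randn(1, 77, 32)  # conditioning, dim == cross_attention_dim

out = unet(sample, timestep, encoder_hidden_states).sample
print(out.shape)  # torch.Size([1, 4, 32, 32])
```

Note that `"swish"` and `"silu"` select the same underlying SiLU nonlinearity (`"swish"` is kept as an alias), an unrecognized string raises the `ValueError` added above, and leaving the argument at its default `None` preserves the previous behavior with no extra activation.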