Skip to content

unet time embedding activation function #3048

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/diffusers/models/unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint

from ..configuration_utils import ConfigMixin, register_to_config
Expand Down Expand Up @@ -101,6 +102,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
class conditioning with `class_embed_type` equal to `None`.
time_embedding_type (`str`, *optional*, default to `positional`):
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
time_embedding_act_fn (`str`, *optional*, default to `None`):
Optional activation function to use on the time embeddings only one time before they as passed to the rest
of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`.
timestep_post_act (`str, *optional*, default to `None`):
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
time_cond_proj_dim (`int`, *optional*, default to `None`):
Expand Down Expand Up @@ -157,6 +161,7 @@ def __init__(
resnet_skip_time_act: bool = False,
resnet_out_scale_factor: int = 1.0,
time_embedding_type: str = "positional",
time_embedding_act_fn: Optional[str] = None,
timestep_post_act: Optional[str] = None,
time_cond_proj_dim: Optional[int] = None,
conv_in_kernel: int = 3,
Expand Down Expand Up @@ -267,6 +272,19 @@ def __init__(
else:
self.class_embedding = None

if time_embedding_act_fn is None:
self.time_embed_act = None
elif time_embedding_act_fn == "swish":
self.time_embed_act = lambda x: F.silu(x)
elif time_embedding_act_fn == "mish":
self.time_embed_act = nn.Mish()
elif time_embedding_act_fn == "silu":
self.time_embed_act = nn.SiLU()
elif time_embedding_act_fn == "gelu":
self.time_embed_act = nn.GELU()
else:
raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")

self.down_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])

Expand Down Expand Up @@ -657,6 +675,9 @@ def forward(
else:
emb = emb + class_emb

if self.time_embed_act is not None:
emb = self.time_embed_act(emb)

Comment on lines +678 to +680
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

optional activation of time embeddings once at at the beginning of the unet

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of curiosity.

Is it being used in the private fork?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yessir!

if self.encoder_hid_proj is not None:
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)

Expand Down
21 changes: 21 additions & 0 deletions src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin
Expand Down Expand Up @@ -182,6 +183,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
class conditioning with `class_embed_type` equal to `None`.
time_embedding_type (`str`, *optional*, default to `positional`):
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
time_embedding_act_fn (`str`, *optional*, default to `None`):
Optional activation function to use on the time embeddings only one time before they as passed to the rest
of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`.
timestep_post_act (`str, *optional*, default to `None`):
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
time_cond_proj_dim (`int`, *optional*, default to `None`):
Expand Down Expand Up @@ -243,6 +247,7 @@ def __init__(
resnet_skip_time_act: bool = False,
resnet_out_scale_factor: int = 1.0,
time_embedding_type: str = "positional",
time_embedding_act_fn: Optional[str] = None,
timestep_post_act: Optional[str] = None,
time_cond_proj_dim: Optional[int] = None,
conv_in_kernel: int = 3,
Expand Down Expand Up @@ -359,6 +364,19 @@ def __init__(
else:
self.class_embedding = None

if time_embedding_act_fn is None:
self.time_embed_act = None
elif time_embedding_act_fn == "swish":
self.time_embed_act = lambda x: F.silu(x)
elif time_embedding_act_fn == "mish":
self.time_embed_act = nn.Mish()
elif time_embedding_act_fn == "silu":
self.time_embed_act = nn.SiLU()
elif time_embedding_act_fn == "gelu":
self.time_embed_act = nn.GELU()
else:
raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")

self.down_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])

Expand Down Expand Up @@ -752,6 +770,9 @@ def forward(
else:
emb = emb + class_emb

if self.time_embed_act is not None:
emb = self.time_embed_act(emb)

if self.encoder_hid_proj is not None:
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)

Expand Down