Skip to content

Commit 55412f3

Browse files
committed
unet time embedding activation function
1 parent fbc9a73 commit 55412f3

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

src/diffusers/models/unet_2d_condition.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import torch
1818
import torch.nn as nn
19+
import torch.nn.functional as F
1920
import torch.utils.checkpoint
2021

2122
from ..configuration_utils import ConfigMixin, register_to_config
@@ -101,6 +102,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
101102
class conditioning with `class_embed_type` equal to `None`.
102103
time_embedding_type (`str`, *optional*, default to `positional`):
103104
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
105+
time_embedding_act_fn (`str`, *optional*, default to `None`):
106+
Optional activation function to use on the time embeddings, applied only once before they are passed to the rest
107+
of the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
104108
timestep_post_act (`str, *optional*, default to `None`):
105109
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
106110
time_cond_proj_dim (`int`, *optional*, default to `None`):
@@ -152,6 +156,7 @@ def __init__(
152156
resnet_skip_time_act: bool = False,
153157
resnet_out_scale_factor: int = 1.0,
154158
time_embedding_type: str = "positional",
159+
time_embedding_act_fn: Optional[str] = None,
155160
timestep_post_act: Optional[str] = None,
156161
time_cond_proj_dim: Optional[int] = None,
157162
conv_in_kernel: int = 3,
@@ -261,6 +266,20 @@ def __init__(
261266
else:
262267
self.class_embedding = None
263268

269+
if time_embedding_act_fn is not None:
270+
if time_embedding_act_fn == "swish":
271+
self.time_embed_act = lambda x: F.silu(x)
272+
elif time_embedding_act_fn == "mish":
273+
self.time_embed_act = nn.Mish()
274+
elif time_embedding_act_fn == "silu":
275+
self.time_embed_act = nn.SiLU()
276+
elif time_embedding_act_fn == "gelu":
277+
self.time_embed_act = nn.GELU()
278+
else:
279+
raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
280+
else:
281+
self.time_embed_act = None
282+
264283
self.down_blocks = nn.ModuleList([])
265284
self.up_blocks = nn.ModuleList([])
266285

@@ -634,6 +653,9 @@ def forward(
634653
else:
635654
emb = emb + class_emb
636655

656+
if self.time_embed_act is not None:
657+
emb = self.time_embed_act(emb)
658+
637659
if self.encoder_hid_proj is not None:
638660
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
639661

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
182182
class conditioning with `class_embed_type` equal to `None`.
183183
time_embedding_type (`str`, *optional*, default to `positional`):
184184
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
185+
time_embedding_act_fn (`str`, *optional*, default to `None`):
186+
Optional activation function to use on the time embeddings, applied only once before they are passed to the rest
187+
of the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
185188
timestep_post_act (`str, *optional*, default to `None`):
186189
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
187190
time_cond_proj_dim (`int`, *optional*, default to `None`):
@@ -238,6 +241,7 @@ def __init__(
238241
resnet_skip_time_act: bool = False,
239242
resnet_out_scale_factor: int = 1.0,
240243
time_embedding_type: str = "positional",
244+
time_embedding_act_fn: Optional[str] = None,
241245
timestep_post_act: Optional[str] = None,
242246
time_cond_proj_dim: Optional[int] = None,
243247
conv_in_kernel: int = 3,
@@ -353,6 +357,20 @@ def __init__(
353357
else:
354358
self.class_embedding = None
355359

360+
if time_embedding_act_fn is not None:
361+
if time_embedding_act_fn == "swish":
362+
self.time_embed_act = lambda x: F.silu(x)
363+
elif time_embedding_act_fn == "mish":
364+
self.time_embed_act = nn.Mish()
365+
elif time_embedding_act_fn == "silu":
366+
self.time_embed_act = nn.SiLU()
367+
elif time_embedding_act_fn == "gelu":
368+
self.time_embed_act = nn.GELU()
369+
else:
370+
raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
371+
else:
372+
self.time_embed_act = None
373+
356374
self.down_blocks = nn.ModuleList([])
357375
self.up_blocks = nn.ModuleList([])
358376

@@ -726,6 +744,9 @@ def forward(
726744
else:
727745
emb = emb + class_emb
728746

747+
if self.time_embed_act is not None:
748+
emb = self.time_embed_act(emb)
749+
729750
if self.encoder_hid_proj is not None:
730751
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
731752

0 commit comments

Comments
 (0)