Skip to content

Commit df4eb1b

Browse files
committed
unet time embedding activation function
1 parent fbc9a73 commit df4eb1b

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

src/diffusers/models/unet_2d_condition.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import torch
1818
import torch.nn as nn
19+
import torch.nn.functional as F
1920
import torch.utils.checkpoint
2021

2122
from ..configuration_utils import ConfigMixin, register_to_config
@@ -101,6 +102,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
101102
class conditioning with `class_embed_type` equal to `None`.
102103
time_embedding_type (`str`, *optional*, default to `positional`):
103104
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
105+
time_embedding_act_fn (`str`, *optional*, default to `None`):
106+
Optional activation function to apply to the time embeddings only once, before they are passed to the rest
107+
of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`.
104108
timestep_post_act (`str`, *optional*, default to `None`):
105109
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
106110
time_cond_proj_dim (`int`, *optional*, default to `None`):
@@ -152,6 +156,7 @@ def __init__(
152156
resnet_skip_time_act: bool = False,
153157
resnet_out_scale_factor: int = 1.0,
154158
time_embedding_type: str = "positional",
159+
time_embedding_act_fn: Optional[str] = None,
155160
timestep_post_act: Optional[str] = None,
156161
time_cond_proj_dim: Optional[int] = None,
157162
conv_in_kernel: int = 3,
@@ -261,6 +266,20 @@ def __init__(
261266
else:
262267
self.class_embedding = None
263268

269+
if time_embedding_act_fn is not None:
270+
if act_fn == "swish":
271+
self.time_embed_act = lambda x: F.silu(x)
272+
elif act_fn == "mish":
273+
self.time_embed_act = nn.Mish()
274+
elif act_fn == "silu":
275+
self.time_embed_act = nn.SiLU()
276+
elif act_fn == "gelu":
277+
self.time_embed_act = nn.GELU()
278+
else:
279+
raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
280+
else:
281+
self.time_embed_act = None
282+
264283
self.down_blocks = nn.ModuleList([])
265284
self.up_blocks = nn.ModuleList([])
266285

@@ -634,6 +653,9 @@ def forward(
634653
else:
635654
emb = emb + class_emb
636655

656+
if self.time_embed_act is not None:
657+
emb = self.time_embed_act(emb)
658+
637659
if self.encoder_hid_proj is not None:
638660
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
639661

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
import torch
55
import torch.nn as nn
6+
import torch.nn.functional as F
67

78
from ...configuration_utils import ConfigMixin, register_to_config
89
from ...models import ModelMixin
@@ -182,6 +183,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
182183
class conditioning with `class_embed_type` equal to `None`.
183184
time_embedding_type (`str`, *optional*, default to `positional`):
184185
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
186+
time_embedding_act_fn (`str`, *optional*, default to `None`):
187+
Optional activation function to apply to the time embeddings only once, before they are passed to the rest
188+
of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`.
185189
timestep_post_act (`str`, *optional*, default to `None`):
186190
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
187191
time_cond_proj_dim (`int`, *optional*, default to `None`):
@@ -238,6 +242,7 @@ def __init__(
238242
resnet_skip_time_act: bool = False,
239243
resnet_out_scale_factor: int = 1.0,
240244
time_embedding_type: str = "positional",
245+
time_embedding_act_fn: Optional[str] = None,
241246
timestep_post_act: Optional[str] = None,
242247
time_cond_proj_dim: Optional[int] = None,
243248
conv_in_kernel: int = 3,
@@ -353,6 +358,20 @@ def __init__(
353358
else:
354359
self.class_embedding = None
355360

361+
if time_embedding_act_fn is not None:
362+
if act_fn == "swish":
363+
self.time_embed_act = lambda x: F.silu(x)
364+
elif act_fn == "mish":
365+
self.time_embed_act = nn.Mish()
366+
elif act_fn == "silu":
367+
self.time_embed_act = nn.SiLU()
368+
elif act_fn == "gelu":
369+
self.time_embed_act = nn.GELU()
370+
else:
371+
raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
372+
else:
373+
self.time_embed_act = None
374+
356375
self.down_blocks = nn.ModuleList([])
357376
self.up_blocks = nn.ModuleList([])
358377

@@ -726,6 +745,9 @@ def forward(
726745
else:
727746
emb = emb + class_emb
728747

748+
if self.time_embed_act is not None:
749+
emb = self.time_embed_act(emb)
750+
729751
if self.encoder_hid_proj is not None:
730752
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
731753

0 commit comments

Comments
 (0)