|
16 | 16 |
|
17 | 17 | import torch
|
18 | 18 | import torch.nn as nn
|
| 19 | +import torch.nn.functional as F |
19 | 20 | import torch.utils.checkpoint
|
20 | 21 |
|
21 | 22 | from ..configuration_utils import ConfigMixin, register_to_config
|
@@ -101,6 +102,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
101 | 102 | class conditioning with `class_embed_type` equal to `None`.
|
102 | 103 | time_embedding_type (`str`, *optional*, default to `positional`):
|
103 | 104 | The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
|
| 105 | + time_embedding_act_fn (`str`, *optional*, default to `None`): |
| 106 | +            Optional activation function to use only once on the time embeddings before they are passed to the rest |
| 107 | + of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`. |
104 | 108 | timestep_post_act (`str, *optional*, default to `None`):
|
105 | 109 | The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
|
106 | 110 | time_cond_proj_dim (`int`, *optional*, default to `None`):
|
@@ -152,6 +156,7 @@ def __init__(
|
152 | 156 | resnet_skip_time_act: bool = False,
|
153 | 157 | resnet_out_scale_factor: int = 1.0,
|
154 | 158 | time_embedding_type: str = "positional",
|
| 159 | + time_embedding_act_fn: Optional[str] = None, |
155 | 160 | timestep_post_act: Optional[str] = None,
|
156 | 161 | time_cond_proj_dim: Optional[int] = None,
|
157 | 162 | conv_in_kernel: int = 3,
|
@@ -261,6 +266,20 @@ def __init__(
|
261 | 266 | else:
|
262 | 267 | self.class_embedding = None
|
263 | 268 |
|
| 269 | +        if time_embedding_act_fn is not None: |
| 270 | +            if time_embedding_act_fn == "swish": |
| 271 | +                self.time_embed_act = lambda x: F.silu(x) |
| 272 | +            elif time_embedding_act_fn == "mish": |
| 273 | +                self.time_embed_act = nn.Mish() |
| 274 | +            elif time_embedding_act_fn == "silu": |
| 275 | +                self.time_embed_act = nn.SiLU() |
| 276 | +            elif time_embedding_act_fn == "gelu": |
| 277 | +                self.time_embed_act = nn.GELU() |
| 278 | +            else: |
| 279 | +                raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}") |
| 280 | +        else: |
| 281 | +            self.time_embed_act = None |
| 282 | + |
264 | 283 | self.down_blocks = nn.ModuleList([])
|
265 | 284 | self.up_blocks = nn.ModuleList([])
|
266 | 285 |
|
@@ -634,6 +653,9 @@ def forward(
|
634 | 653 | else:
|
635 | 654 | emb = emb + class_emb
|
636 | 655 |
|
| 656 | + if self.time_embed_act is not None: |
| 657 | + emb = self.time_embed_act(emb) |
| 658 | + |
637 | 659 | if self.encoder_hid_proj is not None:
|
638 | 660 | encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
|
639 | 661 |
|
|
0 commit comments