From 23569b4397fb39087850e4eaad5388a33ccafc64 Mon Sep 17 00:00:00 2001 From: William Berman Date: Mon, 10 Apr 2023 18:25:48 -0700 Subject: [PATCH 1/3] unet time embedding activation function --- src/diffusers/models/unet_2d_condition.py | 22 +++++++++++++++++++ .../versatile_diffusion/modeling_text_unet.py | 22 +++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 3fb4202ed119..b5d162bec591 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -16,6 +16,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F import torch.utils.checkpoint from ..configuration_utils import ConfigMixin, register_to_config @@ -101,6 +102,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) class conditioning with `class_embed_type` equal to `None`. time_embedding_type (`str`, *optional*, default to `positional`): The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_act_fn (`str`, *optional*, default to `None`): + Optional activation function to use on the time embeddings only one time before they are passed to the rest + of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`. timestep_post_act (`str, *optional*, default to `None`): The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. 
time_cond_proj_dim (`int`, *optional*, default to `None`): @@ -157,6 +161,7 @@ def __init__( resnet_skip_time_act: bool = False, resnet_out_scale_factor: int = 1.0, time_embedding_type: str = "positional", + time_embedding_act_fn: Optional[str] = None, timestep_post_act: Optional[str] = None, time_cond_proj_dim: Optional[int] = None, conv_in_kernel: int = 3, @@ -267,6 +272,20 @@ def __init__( else: self.class_embedding = None + if time_embedding_act_fn is not None: + if act_fn == "swish": + self.time_embed_act = lambda x: F.silu(x) + elif act_fn == "mish": + self.time_embed_act = nn.Mish() + elif act_fn == "silu": + self.time_embed_act = nn.SiLU() + elif act_fn == "gelu": + self.time_embed_act = nn.GELU() + else: + raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}") + else: + self.time_embed_act = None + self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) @@ -657,6 +676,9 @@ def forward( else: emb = emb + class_emb + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + if self.encoder_hid_proj is not None: encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 51d1c62c926b..9a0127562ea3 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -3,6 +3,7 @@ import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ...models import ModelMixin @@ -182,6 +183,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): class conditioning with `class_embed_type` equal to `None`. time_embedding_type (`str`, *optional*, default to `positional`): The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. 
+ time_embedding_act_fn (`str`, *optional*, default to `None`): + Optional activation function to use on the time embeddings only one time before they are passed to the rest + of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`. timestep_post_act (`str, *optional*, default to `None`): The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. time_cond_proj_dim (`int`, *optional*, default to `None`): @@ -243,6 +247,7 @@ def __init__( resnet_skip_time_act: bool = False, resnet_out_scale_factor: int = 1.0, time_embedding_type: str = "positional", + time_embedding_act_fn: Optional[str] = None, timestep_post_act: Optional[str] = None, time_cond_proj_dim: Optional[int] = None, conv_in_kernel: int = 3, @@ -359,6 +364,20 @@ def __init__( else: self.class_embedding = None + if time_embedding_act_fn is not None: + if act_fn == "swish": + self.time_embed_act = lambda x: F.silu(x) + elif act_fn == "mish": + self.time_embed_act = nn.Mish() + elif act_fn == "silu": + self.time_embed_act = nn.SiLU() + elif act_fn == "gelu": + self.time_embed_act = nn.GELU() + else: + raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}") + else: + self.time_embed_act = None + self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) @@ -752,6 +771,9 @@ def forward( else: emb = emb + class_emb + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + if self.encoder_hid_proj is not None: encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) From fcc4c817725d307358955c6f5db8bb1a06178525 Mon Sep 17 00:00:00 2001 From: William Berman Date: Mon, 10 Apr 2023 19:45:57 -0700 Subject: [PATCH 2/3] typo act_fn -> time_embedding_act_fn --- src/diffusers/models/unet_2d_condition.py | 8 ++++---- .../pipelines/versatile_diffusion/modeling_text_unet.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py 
b/src/diffusers/models/unet_2d_condition.py index b5d162bec591..b359c0d3385b 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -273,13 +273,13 @@ def __init__( self.class_embedding = None if time_embedding_act_fn is not None: - if act_fn == "swish": + if time_embedding_act_fn == "swish": self.time_embed_act = lambda x: F.silu(x) - elif act_fn == "mish": + elif time_embedding_act_fn == "mish": self.time_embed_act = nn.Mish() - elif act_fn == "silu": + elif time_embedding_act_fn == "silu": self.time_embed_act = nn.SiLU() - elif act_fn == "gelu": + elif time_embedding_act_fn == "gelu": self.time_embed_act = nn.GELU() else: raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}") diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 9a0127562ea3..8dbd7ff6d549 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -365,13 +365,13 @@ def __init__( self.class_embedding = None if time_embedding_act_fn is not None: - if act_fn == "swish": + if time_embedding_act_fn == "swish": self.time_embed_act = lambda x: F.silu(x) - elif act_fn == "mish": + elif time_embedding_act_fn == "mish": self.time_embed_act = nn.Mish() - elif act_fn == "silu": + elif time_embedding_act_fn == "silu": self.time_embed_act = nn.SiLU() - elif act_fn == "gelu": + elif time_embedding_act_fn == "gelu": self.time_embed_act = nn.GELU() else: raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}") From e309542fe59cd4e3203a5d17c1d57064cca6ec30 Mon Sep 17 00:00:00 2001 From: William Berman Date: Tue, 11 Apr 2023 10:54:14 -0700 Subject: [PATCH 3/3] flatten conditional --- src/diffusers/models/unet_2d_condition.py | 23 +++++++++---------- .../versatile_diffusion/modeling_text_unet.py | 23 +++++++++---------- 2 files 
changed, 22 insertions(+), 24 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index b359c0d3385b..9243dc66d3e8 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -272,19 +272,18 @@ def __init__( else: self.class_embedding = None - if time_embedding_act_fn is not None: - if time_embedding_act_fn == "swish": - self.time_embed_act = lambda x: F.silu(x) - elif time_embedding_act_fn == "mish": - self.time_embed_act = nn.Mish() - elif time_embedding_act_fn == "silu": - self.time_embed_act = nn.SiLU() - elif time_embedding_act_fn == "gelu": - self.time_embed_act = nn.GELU() - else: - raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}") - else: + if time_embedding_act_fn is None: self.time_embed_act = None + elif time_embedding_act_fn == "swish": + self.time_embed_act = lambda x: F.silu(x) + elif time_embedding_act_fn == "mish": + self.time_embed_act = nn.Mish() + elif time_embedding_act_fn == "silu": + self.time_embed_act = nn.SiLU() + elif time_embedding_act_fn == "gelu": + self.time_embed_act = nn.GELU() + else: + raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}") self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 8dbd7ff6d549..cc8cde4daa3b 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -364,19 +364,18 @@ def __init__( else: self.class_embedding = None - if time_embedding_act_fn is not None: - if time_embedding_act_fn == "swish": - self.time_embed_act = lambda x: F.silu(x) - elif time_embedding_act_fn == "mish": - self.time_embed_act = nn.Mish() - elif time_embedding_act_fn == "silu": - self.time_embed_act = nn.SiLU() - elif 
time_embedding_act_fn == "gelu": - self.time_embed_act = nn.GELU() - else: - raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}") - else: + if time_embedding_act_fn is None: self.time_embed_act = None + elif time_embedding_act_fn == "swish": + self.time_embed_act = lambda x: F.silu(x) + elif time_embedding_act_fn == "mish": + self.time_embed_act = nn.Mish() + elif time_embedding_act_fn == "silu": + self.time_embed_act = nn.SiLU() + elif time_embedding_act_fn == "gelu": + self.time_embed_act = nn.GELU() + else: + raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}") self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([])