From 957ed7a03a84fc49c2a9d0d45b9f49e6f5aaaa6d Mon Sep 17 00:00:00 2001 From: William Berman Date: Mon, 17 Apr 2023 15:44:29 -0700 Subject: [PATCH] Add unet act fn to other model components Adding act fn config to the unet timestep class embedding and conv activation. The custom activation defaults to silu which is the default activation function for both the conv act and the timestep class embeddings so default behavior is not changed. The only unet which use the custom activation is the stable diffusion latent upscaler https://huggingface.co/stabilityai/sd-x2-latent-upscaler/blob/main/unet/config.json (I ran a script against the hub to confirm). The latent upscaler does not use the conv activation nor the timestep class embeddings so we don't change its behavior. --- src/diffusers/models/unet_2d_condition.py | 15 +++++++++++++-- .../versatile_diffusion/modeling_text_unet.py | 15 +++++++++++++-- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index b2814356939b..29de8734d4e7 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -248,7 +248,7 @@ def __init__( if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) elif class_embed_type == "projection": @@ -437,7 +437,18 @@ def __init__( self.conv_norm_out = nn.GroupNorm( num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps ) - self.conv_act = nn.SiLU() + + if act_fn == "swish": + self.conv_act = lambda x: F.silu(x) + elif act_fn == "mish": + self.conv_act = nn.Mish() + elif act_fn == "silu": + self.conv_act = nn.SiLU() + elif act_fn == "gelu": + self.conv_act = nn.GELU() + else: + raise ValueError(f"Unsupported activation function: {act_fn}") + else: self.conv_norm_out = None self.conv_act = None diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 4377be1181a8..b20f18c485d0 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -345,7 +345,7 @@ def __init__( if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) elif class_embed_type == "projection": @@ -534,7 +534,18 @@ def __init__( self.conv_norm_out = nn.GroupNorm( num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps ) - self.conv_act = nn.SiLU() + + if act_fn == "swish": + self.conv_act = lambda x: F.silu(x) + elif act_fn == "mish": + self.conv_act = nn.Mish() + elif act_fn == "silu": + self.conv_act = nn.SiLU() + elif act_fn == "gelu": + self.conv_act = nn.GELU() + else: + raise ValueError(f"Unsupported activation function: {act_fn}") + else: self.conv_norm_out = None self.conv_act = None