From 25f5342222ef71089c4ecea5cdad4bd68d83f32f Mon Sep 17 00:00:00 2001 From: hypnopump Date: Mon, 1 May 2023 12:30:20 +0200 Subject: [PATCH 1/4] wider support for gelu. use same torch layer for silu and swish --- src/diffusers/models/attention.py | 6 ++--- src/diffusers/models/resnet.py | 6 ++--- src/diffusers/models/unet_1d_blocks.py | 24 +++++++++---------- src/diffusers/models/unet_2d_condition.py | 12 ++++------ .../versatile_diffusion/modeling_text_unet.py | 12 ++++------ 5 files changed, 24 insertions(+), 36 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index fb5f6f48b324..048a84243c65 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -518,12 +518,10 @@ def __init__( self.num_groups = num_groups self.eps = eps self.act = None - if act_fn == "swish": - self.act = lambda x: F.silu(x) + if act_fn in {"swish", "silu"}: + self.act = nn.SiLU() elif act_fn == "mish": self.act = nn.Mish() - elif act_fn == "silu": - self.act = nn.SiLU() elif act_fn == "gelu": self.act = nn.GELU() diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index d9d539959c09..3106996b4007 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -513,12 +513,10 @@ def __init__( conv_2d_out_channels = conv_2d_out_channels or out_channels self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) - if non_linearity == "swish": - self.nonlinearity = lambda x: F.silu(x) + if non_linearity in {"swish", "silu"}: + self.nonlinearity = nn.SiLU() elif non_linearity == "mish": self.nonlinearity = nn.Mish() - elif non_linearity == "silu": - self.nonlinearity = nn.SiLU() elif non_linearity == "gelu": self.nonlinearity = nn.GELU() diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py index a0f0e58f9103..a767bfeba0be 100644 --- a/src/diffusers/models/unet_1d_blocks.py +++ b/src/diffusers/models/unet_1d_blocks.py @@ -55,12 +55,12 @@ def __init__( self.resnets = nn.ModuleList(resnets) - if non_linearity == "swish": - self.nonlinearity = lambda x: F.silu(x) + if non_linearity in {"swish", "silu"}: + self.nonlinearity = nn.SiLU() elif non_linearity == "mish": self.nonlinearity = nn.Mish() - elif non_linearity == "silu": - self.nonlinearity = nn.SiLU() + elif non_linearity == "gelu": + self.nonlinearity = nn.GELU() else: self.nonlinearity = None @@ -119,12 +119,12 @@ def __init__( self.resnets = nn.ModuleList(resnets) - if non_linearity == "swish": - self.nonlinearity = lambda x: F.silu(x) + if non_linearity in {"swish", "silu"}: + self.nonlinearity = nn.SiLU() elif non_linearity == "mish": self.nonlinearity = nn.Mish() - elif non_linearity == "silu": - self.nonlinearity = nn.SiLU() + elif non_linearity == "gelu": + self.nonlinearity = nn.GELU() else: self.nonlinearity = None @@ -194,12 +194,12 @@ def __init__( self.resnets = nn.ModuleList(resnets) - if non_linearity == "swish": - self.nonlinearity = lambda x: F.silu(x) + if non_linearity in {"swish", "silu"}: + self.nonlinearity = nn.SiLU() elif non_linearity == "mish": self.nonlinearity = nn.Mish() - elif non_linearity == "silu": - self.nonlinearity = nn.SiLU() + elif non_linearity == "gelu": + self.nonlinearity = nn.GELU() else: self.nonlinearity = None diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 38e0fa3b5b2e..23bb5b2b1ac0 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ 
b/src/diffusers/models/unet_2d_condition.py @@ -295,12 +295,10 @@ def __init__( if time_embedding_act_fn is None: self.time_embed_act = None - elif time_embedding_act_fn == "swish": - self.time_embed_act = lambda x: F.silu(x) + elif time_embedding_act_fn in {"swish", "silu"}: + self.time_embed_act = nn.SiLU() elif time_embedding_act_fn == "mish": self.time_embed_act = nn.Mish() - elif time_embedding_act_fn == "silu": - self.time_embed_act = nn.SiLU() elif time_embedding_act_fn == "gelu": self.time_embed_act = nn.GELU() else: @@ -458,12 +456,10 @@ def __init__( num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps ) - if act_fn == "swish": - self.conv_act = lambda x: F.silu(x) + if act_fn in {"swish", "silu"}: + self.conv_act = nn.SiLU() elif act_fn == "mish": self.conv_act = nn.Mish() - elif act_fn == "silu": - self.conv_act = nn.SiLU() elif act_fn == "gelu": self.conv_act = nn.GELU() else: diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 0959e2bb3a8b..fb9f27abc35e 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -398,12 +398,10 @@ def __init__( if time_embedding_act_fn is None: self.time_embed_act = None - elif time_embedding_act_fn == "swish": - self.time_embed_act = lambda x: F.silu(x) + elif time_embedding_act_fn in {"swish", "silu"}: + self.time_embed_act = nn.SiLU() elif time_embedding_act_fn == "mish": self.time_embed_act = nn.Mish() - elif time_embedding_act_fn == "silu": - self.time_embed_act = nn.SiLU() elif time_embedding_act_fn == "gelu": self.time_embed_act = nn.GELU() else: @@ -561,12 +559,10 @@ def __init__( num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps ) - if act_fn == "swish": - self.conv_act = lambda x: F.silu(x) + if act_fn in {"swish", "silu"}: + self.conv_act = nn.SiLU() elif act_fn == "mish": self.conv_act = nn.Mish() - elif act_fn == "silu": - self.conv_act = nn.SiLU() elif act_fn == "gelu": self.conv_act = nn.GELU() else: From 937ee188d9ed4ff37c24edbe87d38ebd54fe6bff Mon Sep 17 00:00:00 2001 From: hypnopump Date: Mon, 1 May 2023 12:37:49 +0200 Subject: [PATCH 2/4] remove useless import --- src/diffusers/models/unet_2d_condition.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 23bb5b2b1ac0..39ad205a67c0 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -16,7 +16,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F import torch.utils.checkpoint from ..configuration_utils import ConfigMixin, register_to_config From 67ff580412e4d4c3de9acc4f231626453f035d61 Mon Sep 17 00:00:00 2001 From: hypnopump Date: Mon, 1 May 2023 12:40:15 +0200 Subject: [PATCH 3/4] remove unused activation --- src/diffusers/models/unet_1d_blocks.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py index a767bfeba0be..03965e6f8790 100644 --- a/src/diffusers/models/unet_1d_blocks.py +++ b/src/diffusers/models/unet_1d_blocks.py @@ -179,7 +179,6 @@ def __init__( num_layers: int = 1, add_downsample: bool = False, add_upsample: bool = False, - non_linearity=None, ): super().__init__() self.in_channels = in_channels @@ -194,15 +193,6 @@ def __init__( self.resnets = nn.ModuleList(resnets) - 
if non_linearity in {"swish", "silu"}:
- self.nonlinearity = nn.SiLU()
- elif non_linearity == "mish":
- self.nonlinearity = nn.Mish()
- elif non_linearity == "gelu":
- self.nonlinearity = nn.GELU()
- else:
- self.nonlinearity = None
-
 self.upsample = None
 if add_upsample:
 self.upsample = Downsample1D(out_channels, use_conv=True)

From 1f26ebe43aa12665aa35bfc5901a604941649461 Mon Sep 17 00:00:00 2001
From: hypnopump
Date: Mon, 1 May 2023 12:42:46 +0200
Subject: [PATCH 4/4] raise error at construction if the act_fn param is invalid, rather than failing at runtime.

---
 src/diffusers/models/unet_1d_blocks.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py
index 03965e6f8790..dc6f45399fe7 100644
--- a/src/diffusers/models/unet_1d_blocks.py
+++ b/src/diffusers/models/unet_1d_blocks.py
@@ -222,10 +222,14 @@ def __init__(self, num_groups_out, out_channels, embed_dim, act_fn):
 super().__init__()
 self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
 self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
- if act_fn == "silu":
+ if act_fn in {"silu", "swish"}:
 self.final_conv1d_act = nn.SiLU()
- if act_fn == "mish":
+ elif act_fn == "mish":
 self.final_conv1d_act = nn.Mish()
+ elif act_fn == "gelu":
+ self.final_conv1d_act = nn.GELU()
+ else:
+ raise ValueError(f"act_fn {act_fn} must be one of: silu, swish, mish, gelu")
 self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)

 def forward(self, hidden_states, temb=None):
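
Note on a possible follow-up: after these four patches, the same if/elif ladder is still repeated at every call site (attention.py, resnet.py, unet_1d_blocks.py, unet_2d_condition.py, modeling_text_unet.py). A natural next step would be to factor it into one shared factory. The sketch below is illustrative only; the get_activation name, its module location, and the None passthrough are assumptions for this note, not part of the patch series.

    # Sketch only: a hypothetical shared factory for the if/elif ladders above.
    from typing import Optional

    import torch.nn as nn

    def get_activation(act_fn: Optional[str]) -> Optional[nn.Module]:
        """Map a config string to its torch activation module."""
        if act_fn is None:
            # some 1D blocks treat a missing non_linearity as "no activation"
            return None
        if act_fn in {"swish", "silu"}:
            # swish and silu are the same function; reuse one torch layer,
            # mirroring the consolidation done in PATCH 1/4
            return nn.SiLU()
        elif act_fn == "mish":
            return nn.Mish()
        elif act_fn == "gelu":
            return nn.GELU()
        raise ValueError(f"act_fn {act_fn} must be one of: silu, swish, mish, gelu")

Each constructor would then reduce to a single assignment such as self.nonlinearity = get_activation(non_linearity), and the construction-time ValueError introduced in PATCH 4/4 would apply uniformly at every call site instead of only in the final 1D block.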