diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 134f84fc9d50..740993d80c1b 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -520,12 +520,10 @@ def __init__(
         self.num_groups = num_groups
         self.eps = eps
         self.act = None
-        if act_fn == "swish":
-            self.act = lambda x: F.silu(x)
+        if act_fn in {"swish", "silu"}:
+            self.act = nn.SiLU()
         elif act_fn == "mish":
             self.act = nn.Mish()
-        elif act_fn == "silu":
-            self.act = nn.SiLU()
         elif act_fn == "gelu":
             self.act = nn.GELU()
 
diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index d9d539959c09..3106996b4007 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -513,12 +513,10 @@ def __init__(
         conv_2d_out_channels = conv_2d_out_channels or out_channels
         self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
 
-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
+        if non_linearity in {"swish", "silu"}:
+            self.nonlinearity = nn.SiLU()
         elif non_linearity == "mish":
             self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
         elif non_linearity == "gelu":
             self.nonlinearity = nn.GELU()
 
diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py
index a0f0e58f9103..dc6f45399fe7 100644
--- a/src/diffusers/models/unet_1d_blocks.py
+++ b/src/diffusers/models/unet_1d_blocks.py
@@ -55,12 +55,12 @@ def __init__(
 
         self.resnets = nn.ModuleList(resnets)
 
-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
+        if non_linearity in {"swish", "silu"}:
+            self.nonlinearity = nn.SiLU()
         elif non_linearity == "mish":
             self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
+        elif non_linearity == "gelu":
+            self.nonlinearity = nn.GELU()
         else:
             self.nonlinearity = None
 
@@ -119,12 +119,12 @@ def __init__(
 
         self.resnets = nn.ModuleList(resnets)
 
-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
+        if non_linearity in {"swish", "silu"}:
+            self.nonlinearity = nn.SiLU()
         elif non_linearity == "mish":
             self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
+        elif non_linearity == "gelu":
+            self.nonlinearity = nn.GELU()
         else:
             self.nonlinearity = None
 
@@ -179,7 +179,6 @@ def __init__(
         num_layers: int = 1,
         add_downsample: bool = False,
         add_upsample: bool = False,
-        non_linearity=None,
     ):
         super().__init__()
         self.in_channels = in_channels
@@ -194,15 +193,6 @@ def __init__(
 
         self.resnets = nn.ModuleList(resnets)
 
-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
-        elif non_linearity == "mish":
-            self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
-        else:
-            self.nonlinearity = None
-
         self.upsample = None
         if add_upsample:
             self.upsample = Downsample1D(out_channels, use_conv=True)
@@ -232,10 +222,14 @@ def __init__(self, num_groups_out, out_channels, embed_dim, act_fn):
         super().__init__()
         self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
         self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
-        if act_fn == "silu":
+        if act_fn in {"silu", "swish"}:
             self.final_conv1d_act = nn.SiLU()
-        if act_fn == "mish":
+        elif act_fn == "mish":
             self.final_conv1d_act = nn.Mish()
+        elif act_fn == "gelu":
+            self.final_conv1d_act = nn.GELU()
+        else:
+            raise ValueError(f"Act_fn {act_fn} must be one of silu, mish, gelu")
         self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
 
     def forward(self, hidden_states, temb=None):
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 2a4c9fd72c1b..5871f151d5f2 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -16,7 +16,6 @@
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 import torch.utils.checkpoint
 
 from ..configuration_utils import ConfigMixin, register_to_config
@@ -295,12 +294,10 @@ def __init__(
 
         if time_embedding_act_fn is None:
             self.time_embed_act = None
-        elif time_embedding_act_fn == "swish":
-            self.time_embed_act = lambda x: F.silu(x)
+        elif time_embedding_act_fn in {"swish", "silu"}:
+            self.time_embed_act = nn.SiLU()
         elif time_embedding_act_fn == "mish":
             self.time_embed_act = nn.Mish()
-        elif time_embedding_act_fn == "silu":
-            self.time_embed_act = nn.SiLU()
         elif time_embedding_act_fn == "gelu":
             self.time_embed_act = nn.GELU()
         else:
@@ -458,12 +455,10 @@ def __init__(
             num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
         )
 
-        if act_fn == "swish":
-            self.conv_act = lambda x: F.silu(x)
+        if act_fn in {"swish", "silu"}:
+            self.conv_act = nn.SiLU()
         elif act_fn == "mish":
             self.conv_act = nn.Mish()
-        elif act_fn == "silu":
-            self.conv_act = nn.SiLU()
         elif act_fn == "gelu":
             self.conv_act = nn.GELU()
         else:
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index f0a210339c46..31c186f6c4b8 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -398,12 +398,10 @@ def __init__(
 
         if time_embedding_act_fn is None:
             self.time_embed_act = None
-        elif time_embedding_act_fn == "swish":
-            self.time_embed_act = lambda x: F.silu(x)
+        elif time_embedding_act_fn in {"swish", "silu"}:
+            self.time_embed_act = nn.SiLU()
         elif time_embedding_act_fn == "mish":
             self.time_embed_act = nn.Mish()
-        elif time_embedding_act_fn == "silu":
-            self.time_embed_act = nn.SiLU()
         elif time_embedding_act_fn == "gelu":
             self.time_embed_act = nn.GELU()
         else:
@@ -561,12 +559,10 @@ def __init__(
             num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
         )
 
-        if act_fn == "swish":
-            self.conv_act = lambda x: F.silu(x)
+        if act_fn in {"swish", "silu"}:
+            self.conv_act = nn.SiLU()
         elif act_fn == "mish":
             self.conv_act = nn.Mish()
-        elif act_fn == "silu":
-            self.conv_act = nn.SiLU()
         elif act_fn == "gelu":
             self.conv_act = nn.GELU()
         else:
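The same branch collapse recurs at every call site above, so for review purposes here is a minimal sketch of the dispatch it implements. `get_activation` is a hypothetical helper name used only for illustration; this PR inlines the branches at each call site rather than introducing a shared helper. The behavioral points are that "swish" and "silu" are two names for the same function, silu(x) = x * sigmoid(x), and that a stateless `nn.SiLU()` module is numerically identical to the old `lambda x: F.silu(x)`.

```python
# Illustrative sketch only: get_activation is a hypothetical helper, not
# part of this PR, which repeats these branches at each call site.
import torch
import torch.nn as nn
import torch.nn.functional as F


def get_activation(act_fn: str) -> nn.Module:
    # "swish" and "silu" both name silu(x) = x * sigmoid(x),
    # so a single branch covers them.
    if act_fn in {"swish", "silu"}:
        return nn.SiLU()
    elif act_fn == "mish":
        return nn.Mish()
    elif act_fn == "gelu":
        return nn.GELU()
    raise ValueError(f"Unsupported activation function: {act_fn}")


x = torch.randn(4, 8)
# The module form computes exactly what the old lambda form did.
assert torch.equal(get_activation("swish")(x), F.silu(x))
```

Because `nn.SiLU()` (like `nn.Mish()` and `nn.GELU()`) has no parameters, swapping the lambdas for modules leaves state dicts and existing checkpoints untouched; unlike a lambda, though, a module attribute keeps the parent model picklable and shows up in its repr.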