diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py
new file mode 100644
index 000000000000..64759b706e2f
--- /dev/null
+++ b/src/diffusers/models/activations.py
@@ -0,0 +1,12 @@
+from torch import nn
+
+
+def get_activation(act_fn):
+    if act_fn in ["swish", "silu"]:
+        return nn.SiLU()
+    elif act_fn == "mish":
+        return nn.Mish()
+    elif act_fn == "gelu":
+        return nn.GELU()
+    else:
+        raise ValueError(f"Unsupported activation function: {act_fn}")
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index a7a9a472d9e9..8805257ebe9a 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -18,6 +18,7 @@
 from torch import nn
 
 from ..utils import maybe_allow_in_graph
+from .activations import get_activation
 from .attention_processor import Attention
 from .embeddings import CombinedTimestepLabelEmbeddings
 
@@ -345,15 +346,11 @@ def __init__(
         super().__init__()
         self.num_groups = num_groups
         self.eps = eps
-        self.act = None
-        if act_fn == "swish":
-            self.act = lambda x: F.silu(x)
-        elif act_fn == "mish":
-            self.act = nn.Mish()
-        elif act_fn == "silu":
-            self.act = nn.SiLU()
-        elif act_fn == "gelu":
-            self.act = nn.GELU()
+
+        if act_fn is None:
+            self.act = None
+        else:
+            self.act = get_activation(act_fn)
 
         self.linear = nn.Linear(embedding_dim, out_dim * 2)
 
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 991264a9aa8f..4dd16f0dd5ff 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -18,6 +18,8 @@
 import torch
 from torch import nn
 
+from .activations import get_activation
+
 
 def get_timestep_embedding(
     timesteps: torch.Tensor,
@@ -171,14 +173,7 @@ def __init__(
         else:
             self.cond_proj = None
 
-        if act_fn == "silu":
-            self.act = nn.SiLU()
-        elif act_fn == "mish":
-            self.act = nn.Mish()
-        elif act_fn == "gelu":
-            self.act = nn.GELU()
-        else:
-            raise ValueError(f"{act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
+        self.act = get_activation(act_fn)
 
         if out_dim is not None:
             time_embed_dim_out = out_dim
@@ -188,14 +183,8 @@ def __init__(
 
         if post_act_fn is None:
             self.post_act = None
-        elif post_act_fn == "silu":
-            self.post_act = nn.SiLU()
-        elif post_act_fn == "mish":
-            self.post_act = nn.Mish()
-        elif post_act_fn == "gelu":
-            self.post_act = nn.GELU()
         else:
-            raise ValueError(f"{post_act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
+            self.post_act = get_activation(post_act_fn)
 
     def forward(self, sample, condition=None):
         if condition is not None:
diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 3380a4909372..52f01552c528 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -20,6 +20,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from .activations import get_activation
 from .attention import AdaGroupNorm
 from .attention_processor import SpatialNorm
 
@@ -558,14 +559,7 @@ def __init__(
         conv_2d_out_channels = conv_2d_out_channels or out_channels
         self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
 
-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
-        elif non_linearity == "mish":
-            self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
-        elif non_linearity == "gelu":
-            self.nonlinearity = nn.GELU()
+        self.nonlinearity = get_activation(non_linearity)
 
         self.upsample = self.downsample = None
         if self.up:
@@ -646,11 +640,6 @@ def forward(self, input_tensor, temb):
         return output_tensor
 
 
-class Mish(torch.nn.Module):
-    def forward(self, hidden_states):
-        return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
-
-
 # unet_rl.py
 def rearrange_dims(tensor):
     if len(tensor.shape) == 2:
diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py
index 934a4a4a7dcb..3c04bffeeacc 100644
--- a/src/diffusers/models/unet_1d_blocks.py
+++ b/src/diffusers/models/unet_1d_blocks.py
@@ -17,6 +17,7 @@
 import torch.nn.functional as F
 from torch import nn
 
+from .activations import get_activation
 from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims
 
 
@@ -55,14 +56,10 @@ def __init__(
 
         self.resnets = nn.ModuleList(resnets)
 
-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
-        elif non_linearity == "mish":
-            self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
-        else:
+        if non_linearity is None:
             self.nonlinearity = None
+        else:
+            self.nonlinearity = get_activation(non_linearity)
 
         self.downsample = None
         if add_downsample:
@@ -119,14 +116,10 @@ def __init__(
 
         self.resnets = nn.ModuleList(resnets)
 
-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
-        elif non_linearity == "mish":
-            self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
-        else:
+        if non_linearity is None:
             self.nonlinearity = None
+        else:
+            self.nonlinearity = get_activation(non_linearity)
 
         self.upsample = None
         if add_upsample:
@@ -194,14 +187,10 @@ def __init__(
 
         self.resnets = nn.ModuleList(resnets)
 
-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
-        elif non_linearity == "mish":
-            self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
-        else:
+        if non_linearity is None:
             self.nonlinearity = None
+        else:
+            self.nonlinearity = get_activation(non_linearity)
 
         self.upsample = None
         if add_upsample:
@@ -232,10 +221,7 @@ def __init__(self, num_groups_out, out_channels, embed_dim, act_fn):
         super().__init__()
         self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
         self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
-        if act_fn == "silu":
-            self.final_conv1d_act = nn.SiLU()
-        if act_fn == "mish":
-            self.final_conv1d_act = nn.Mish()
+        self.final_conv1d_act = get_activation(act_fn)
         self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
 
     def forward(self, hidden_states, temb=None):
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 106346070d94..dda21fd80479 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -16,12 +16,12 @@
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 import torch.utils.checkpoint
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
 from ..utils import BaseOutput, logging
+from .activations import get_activation
 from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import (
     GaussianFourierProjection,
@@ -338,16 +338,8 @@ def __init__(
 
         if time_embedding_act_fn is None:
             self.time_embed_act = None
-        elif time_embedding_act_fn == "swish":
-            self.time_embed_act = lambda x: F.silu(x)
-        elif time_embedding_act_fn == "mish":
-            self.time_embed_act = nn.Mish()
-        elif time_embedding_act_fn == "silu":
-            self.time_embed_act = nn.SiLU()
-        elif time_embedding_act_fn == "gelu":
-            self.time_embed_act = nn.GELU()
         else:
-            raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
+            self.time_embed_act = get_activation(time_embedding_act_fn)
 
         self.down_blocks = nn.ModuleList([])
         self.up_blocks = nn.ModuleList([])
@@ -501,16 +493,7 @@ def __init__(
                 num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
             )
 
-            if act_fn == "swish":
-                self.conv_act = lambda x: F.silu(x)
-            elif act_fn == "mish":
-                self.conv_act = nn.Mish()
-            elif act_fn == "silu":
-                self.conv_act = nn.SiLU()
-            elif act_fn == "gelu":
-                self.conv_act = nn.GELU()
-            else:
-                raise ValueError(f"Unsupported activation function: {act_fn}")
+            self.conv_act = get_activation(act_fn)
 
         else:
             self.conv_norm_out = None
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index a0dbdaa75230..f11729451299 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -7,6 +7,7 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models import ModelMixin
+from ...models.activations import get_activation
 from ...models.attention import Attention
 from ...models.attention_processor import (
     AttentionProcessor,
@@ -441,16 +442,8 @@ def __init__(
 
         if time_embedding_act_fn is None:
             self.time_embed_act = None
-        elif time_embedding_act_fn == "swish":
-            self.time_embed_act = lambda x: F.silu(x)
-        elif time_embedding_act_fn == "mish":
-            self.time_embed_act = nn.Mish()
-        elif time_embedding_act_fn == "silu":
-            self.time_embed_act = nn.SiLU()
-        elif time_embedding_act_fn == "gelu":
-            self.time_embed_act = nn.GELU()
         else:
-            raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
+            self.time_embed_act = get_activation(time_embedding_act_fn)
 
         self.down_blocks = nn.ModuleList([])
         self.up_blocks = nn.ModuleList([])
@@ -604,16 +597,7 @@ def __init__(
                 num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
             )
 
-            if act_fn == "swish":
-                self.conv_act = lambda x: F.silu(x)
-            elif act_fn == "mish":
-                self.conv_act = nn.Mish()
-            elif act_fn == "silu":
-                self.conv_act = nn.SiLU()
-            elif act_fn == "gelu":
-                self.conv_act = nn.GELU()
-            else:
-                raise ValueError(f"Unsupported activation function: {act_fn}")
+            self.conv_act = get_activation(act_fn)
 
         else:
             self.conv_norm_out = None
diff --git a/tests/models/test_activations.py b/tests/models/test_activations.py
new file mode 100644
index 000000000000..4e8e51453e98
--- /dev/null
+++ b/tests/models/test_activations.py
@@ -0,0 +1,48 @@
+import unittest
+
+import torch
+from torch import nn
+
+from diffusers.models.activations import get_activation
+
+
+class ActivationsTests(unittest.TestCase):
+    def test_swish(self):
+        act = get_activation("swish")
+
+        self.assertIsInstance(act, nn.SiLU)
+
+        self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0)
+        self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)
+
+    def test_silu(self):
+        act = get_activation("silu")
+
+        self.assertIsInstance(act, nn.SiLU)
+
+        self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0)
+        self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)
+
+    def test_mish(self):
+        act = get_activation("mish")
+
+        self.assertIsInstance(act, nn.Mish)
+
+        self.assertEqual(act(torch.tensor(-200, dtype=torch.float32)).item(), 0)
+        self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)
+
+    def test_gelu(self):
+        act = get_activation("gelu")
+
+        self.assertIsInstance(act, nn.GELU)
+
+        self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0)
+        self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)
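Usage sketch (illustrative only, not part of the patch): how the `get_activation` helper introduced above behaves, based on its implementation in `src/diffusers/models/activations.py`.

```python
import torch

from diffusers.models.activations import get_activation

# "swish" and "silu" both map to nn.SiLU(); "mish" and "gelu" are also supported.
act = get_activation("swish")
x = torch.randn(2, 4)
print(act(x).shape)  # torch.Size([2, 4])

# Any other name raises a ValueError instead of silently returning None.
try:
    get_activation("tanh")
except ValueError as err:
    print(err)  # Unsupported activation function: tanh
```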