move activation dispatches into helper function #3656

Merged · 2 commits · Jun 5, 2023
12 changes: 12 additions & 0 deletions src/diffusers/models/activations.py
@@ -0,0 +1,12 @@
from torch import nn


def get_activation(act_fn):
if act_fn in ["swish", "silu"]:
return nn.SiLU()
elif act_fn == "mish":
return nn.Mish()
elif act_fn == "gelu":
return nn.GELU()
else:
raise ValueError(f"Unsupported activation function: {act_fn}")
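
For context, a minimal usage sketch of the new helper (illustrative only; the tensor shape and the "tanh" probe below are arbitrary, not part of the PR):

import torch

from diffusers.models.activations import get_activation

# The helper maps a string name to a torch activation module;
# "swish" and "silu" both resolve to nn.SiLU.
act = get_activation("swish")
out = act(torch.randn(2, 4))

# Unrecognized names raise a ValueError instead of silently leaving the
# activation undefined, as some of the replaced inline dispatches did.
try:
    get_activation("tanh")
except ValueError as err:
    print(err)  # Unsupported activation function: tanh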
15 changes: 6 additions & 9 deletions src/diffusers/models/attention.py
@@ -18,6 +18,7 @@
from torch import nn

from ..utils import maybe_allow_in_graph
from .activations import get_activation
from .attention_processor import Attention
from .embeddings import CombinedTimestepLabelEmbeddings

@@ -345,15 +346,11 @@ def __init__(
super().__init__()
self.num_groups = num_groups
self.eps = eps
self.act = None
if act_fn == "swish":
self.act = lambda x: F.silu(x)
elif act_fn == "mish":
self.act = nn.Mish()
elif act_fn == "silu":
self.act = nn.SiLU()
elif act_fn == "gelu":
self.act = nn.GELU()

if act_fn is None:
self.act = None
else:
self.act = get_activation(act_fn)

self.linear = nn.Linear(embedding_dim, out_dim * 2)

19 changes: 4 additions & 15 deletions src/diffusers/models/embeddings.py
@@ -18,6 +18,8 @@
import torch
from torch import nn

from .activations import get_activation


def get_timestep_embedding(
timesteps: torch.Tensor,
@@ -171,14 +173,7 @@ def __init__(
else:
self.cond_proj = None

if act_fn == "silu":
self.act = nn.SiLU()
elif act_fn == "mish":
self.act = nn.Mish()
elif act_fn == "gelu":
self.act = nn.GELU()
else:
raise ValueError(f"{act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
self.act = get_activation(act_fn)

if out_dim is not None:
time_embed_dim_out = out_dim
@@ -188,14 +183,8 @@

if post_act_fn is None:
self.post_act = None
elif post_act_fn == "silu":
self.post_act = nn.SiLU()
elif post_act_fn == "mish":
self.post_act = nn.Mish()
elif post_act_fn == "gelu":
self.post_act = nn.GELU()
else:
raise ValueError(f"{post_act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
self.post_act = get_activation(post_act_fn)

def forward(self, sample, condition=None):
if condition is not None:
15 changes: 2 additions & 13 deletions src/diffusers/models/resnet.py
@@ -20,6 +20,7 @@
import torch.nn as nn
import torch.nn.functional as F

from .activations import get_activation
from .attention import AdaGroupNorm
from .attention_processor import SpatialNorm

@@ -558,14 +559,7 @@ def __init__(
conv_2d_out_channels = conv_2d_out_channels or out_channels
self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

if non_linearity == "swish":
self.nonlinearity = lambda x: F.silu(x)
elif non_linearity == "mish":
self.nonlinearity = nn.Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
elif non_linearity == "gelu":
self.nonlinearity = nn.GELU()
self.nonlinearity = get_activation(non_linearity)

self.upsample = self.downsample = None
if self.up:
@@ -646,11 +640,6 @@ def forward(self, input_tensor, temb):
return output_tensor


class Mish(torch.nn.Module):
def forward(self, hidden_states):
return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))

Comment on lines -649 to -652 (Contributor Author): Not used

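For reference, the deleted custom Mish module computes the same thing as PyTorch's built-in nn.Mish, which get_activation("mish") now returns. A quick sanity check (illustrative, not part of this PR):

import torch
import torch.nn.functional as F
from torch import nn

x = torch.randn(16)
custom = x * torch.tanh(F.softplus(x))  # formula used by the removed Mish class
builtin = nn.Mish()(x)                  # module returned by get_activation("mish")
assert torch.allclose(custom, builtin)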

# unet_rl.py
def rearrange_dims(tensor):
if len(tensor.shape) == 2:
36 changes: 11 additions & 25 deletions src/diffusers/models/unet_1d_blocks.py
@@ -17,6 +17,7 @@
import torch.nn.functional as F
from torch import nn

from .activations import get_activation
from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims


@@ -55,14 +56,10 @@ def __init__(

self.resnets = nn.ModuleList(resnets)

if non_linearity == "swish":
self.nonlinearity = lambda x: F.silu(x)
elif non_linearity == "mish":
self.nonlinearity = nn.Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
else:
if non_linearity is None:
self.nonlinearity = None
else:
self.nonlinearity = get_activation(non_linearity)

self.downsample = None
if add_downsample:
@@ -119,14 +116,10 @@ def __init__(

self.resnets = nn.ModuleList(resnets)

if non_linearity == "swish":
self.nonlinearity = lambda x: F.silu(x)
elif non_linearity == "mish":
self.nonlinearity = nn.Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
else:
if non_linearity is None:
self.nonlinearity = None
else:
self.nonlinearity = get_activation(non_linearity)

self.upsample = None
if add_upsample:
@@ -194,14 +187,10 @@ def __init__(

self.resnets = nn.ModuleList(resnets)

if non_linearity == "swish":
self.nonlinearity = lambda x: F.silu(x)
elif non_linearity == "mish":
self.nonlinearity = nn.Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
else:
if non_linearity is None:
self.nonlinearity = None
else:
self.nonlinearity = get_activation(non_linearity)

self.upsample = None
if add_upsample:
@@ -232,10 +221,7 @@ def __init__(self, num_groups_out, out_channels, embed_dim, act_fn):
super().__init__()
self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
if act_fn == "silu":
self.final_conv1d_act = nn.SiLU()
if act_fn == "mish":
self.final_conv1d_act = nn.Mish()
self.final_conv1d_act = get_activation(act_fn)
self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)

def forward(self, hidden_states, temb=None):
23 changes: 3 additions & 20 deletions src/diffusers/models/unet_2d_condition.py
@@ -16,12 +16,12 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint

from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import BaseOutput, logging
from .activations import get_activation
from .attention_processor import AttentionProcessor, AttnProcessor
from .embeddings import (
GaussianFourierProjection,
@@ -338,16 +338,8 @@ def __init__(

if time_embedding_act_fn is None:
self.time_embed_act = None
elif time_embedding_act_fn == "swish":
self.time_embed_act = lambda x: F.silu(x)
elif time_embedding_act_fn == "mish":
self.time_embed_act = nn.Mish()
elif time_embedding_act_fn == "silu":
self.time_embed_act = nn.SiLU()
elif time_embedding_act_fn == "gelu":
self.time_embed_act = nn.GELU()
else:
raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
self.time_embed_act = get_activation(time_embedding_act_fn)

self.down_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
@@ -501,16 +493,7 @@ def __init__(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
)

if act_fn == "swish":
self.conv_act = lambda x: F.silu(x)
elif act_fn == "mish":
self.conv_act = nn.Mish()
elif act_fn == "silu":
self.conv_act = nn.SiLU()
elif act_fn == "gelu":
self.conv_act = nn.GELU()
else:
raise ValueError(f"Unsupported activation function: {act_fn}")
self.conv_act = get_activation(act_fn)

else:
self.conv_norm_out = None
22 changes: 3 additions & 19 deletions src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -7,6 +7,7 @@

from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin
from ...models.activations import get_activation
from ...models.attention import Attention
from ...models.attention_processor import (
AttentionProcessor,
@@ -441,16 +442,8 @@ def __init__(

if time_embedding_act_fn is None:
self.time_embed_act = None
elif time_embedding_act_fn == "swish":
self.time_embed_act = lambda x: F.silu(x)
elif time_embedding_act_fn == "mish":
self.time_embed_act = nn.Mish()
elif time_embedding_act_fn == "silu":
self.time_embed_act = nn.SiLU()
elif time_embedding_act_fn == "gelu":
self.time_embed_act = nn.GELU()
else:
raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
self.time_embed_act = get_activation(time_embedding_act_fn)

self.down_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
@@ -604,16 +597,7 @@ def __init__(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
)

if act_fn == "swish":
self.conv_act = lambda x: F.silu(x)
elif act_fn == "mish":
self.conv_act = nn.Mish()
elif act_fn == "silu":
self.conv_act = nn.SiLU()
elif act_fn == "gelu":
self.conv_act = nn.GELU()
else:
raise ValueError(f"Unsupported activation function: {act_fn}")
self.conv_act = get_activation(act_fn)

else:
self.conv_norm_out = None
48 changes: 48 additions & 0 deletions tests/models/test_activations.py
@@ -0,0 +1,48 @@
import unittest

import torch
from torch import nn

from diffusers.models.activations import get_activation


class ActivationsTests(unittest.TestCase):
def test_swish(self):
act = get_activation("swish")

self.assertIsInstance(act, nn.SiLU)

self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0)
self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)

def test_silu(self):
act = get_activation("silu")

self.assertIsInstance(act, nn.SiLU)

self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0)
self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)

def test_mish(self):
act = get_activation("mish")

self.assertIsInstance(act, nn.Mish)

self.assertEqual(act(torch.tensor(-200, dtype=torch.float32)).item(), 0)
self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)

def test_gelu(self):
act = get_activation("gelu")

self.assertIsInstance(act, nn.GELU)

self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0)
self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)
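
To run just this new test module locally (assuming a development install of the repository with pytest available):

python -m pytest tests/models/test_activations.py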