
Commit 41ae670

move activation dispatches into helper function (#3656)
* move activation dispatches into helper function
* tests
1 parent 462956b commit 41ae670
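In short, the commit replaces the per-module if/elif activation dispatch with one shared helper. A schematic before/after sketch (the builder function names are illustrative, not from the diff; get_activation is the helper added below):

# Before this commit: each module carried its own string-to-module dispatch.
from torch import nn

def build_act_before(act_fn):
    if act_fn in ["swish", "silu"]:
        return nn.SiLU()
    elif act_fn == "mish":
        return nn.Mish()
    elif act_fn == "gelu":
        return nn.GELU()
    else:
        raise ValueError(f"Unsupported activation function: {act_fn}")

# After this commit: every module delegates to the shared helper.
from diffusers.models.activations import get_activation

def build_act_after(act_fn):
    return get_activation(act_fn)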

8 files changed: +89 additions, -101 deletions

8 files changed

+89
-101
lines changed

src/diffusers/models/activations.py

Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
+from torch import nn
+
+
+def get_activation(act_fn):
+    if act_fn in ["swish", "silu"]:
+        return nn.SiLU()
+    elif act_fn == "mish":
+        return nn.Mish()
+    elif act_fn == "gelu":
+        return nn.GELU()
+    else:
+        raise ValueError(f"Unsupported activation function: {act_fn}")

src/diffusers/models/attention.py

Lines changed: 6 additions & 9 deletions

@@ -18,6 +18,7 @@
 from torch import nn
 
 from ..utils import maybe_allow_in_graph
+from .activations import get_activation
 from .attention_processor import Attention
 from .embeddings import CombinedTimestepLabelEmbeddings
 
@@ -345,15 +346,11 @@ def __init__(
         super().__init__()
         self.num_groups = num_groups
         self.eps = eps
-        self.act = None
-        if act_fn == "swish":
-            self.act = lambda x: F.silu(x)
-        elif act_fn == "mish":
-            self.act = nn.Mish()
-        elif act_fn == "silu":
-            self.act = nn.SiLU()
-        elif act_fn == "gelu":
-            self.act = nn.GELU()
+
+        if act_fn is None:
+            self.act = None
+        else:
+            self.act = get_activation(act_fn)
 
         self.linear = nn.Linear(embedding_dim, out_dim * 2)

src/diffusers/models/embeddings.py

Lines changed: 4 additions & 15 deletions

@@ -18,6 +18,8 @@
 import torch
 from torch import nn
 
+from .activations import get_activation
+
 
 def get_timestep_embedding(
     timesteps: torch.Tensor,
@@ -171,14 +173,7 @@ def __init__(
         else:
             self.cond_proj = None
 
-        if act_fn == "silu":
-            self.act = nn.SiLU()
-        elif act_fn == "mish":
-            self.act = nn.Mish()
-        elif act_fn == "gelu":
-            self.act = nn.GELU()
-        else:
-            raise ValueError(f"{act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
+        self.act = get_activation(act_fn)
 
         if out_dim is not None:
             time_embed_dim_out = out_dim
@@ -188,14 +183,8 @@ def __init__(
 
         if post_act_fn is None:
             self.post_act = None
-        elif post_act_fn == "silu":
-            self.post_act = nn.SiLU()
-        elif post_act_fn == "mish":
-            self.post_act = nn.Mish()
-        elif post_act_fn == "gelu":
-            self.post_act = nn.GELU()
         else:
-            raise ValueError(f"{post_act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
+            self.post_act = get_activation(post_act_fn)
 
     def forward(self, sample, condition=None):
         if condition is not None:

src/diffusers/models/resnet.py

Lines changed: 2 additions & 13 deletions

@@ -20,6 +20,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from .activations import get_activation
 from .attention import AdaGroupNorm
 from .attention_processor import SpatialNorm
 
@@ -558,14 +559,7 @@ def __init__(
         conv_2d_out_channels = conv_2d_out_channels or out_channels
         self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
 
-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
-        elif non_linearity == "mish":
-            self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
-        elif non_linearity == "gelu":
-            self.nonlinearity = nn.GELU()
+        self.nonlinearity = get_activation(non_linearity)
 
         self.upsample = self.downsample = None
         if self.up:
@@ -646,11 +640,6 @@ def forward(self, input_tensor, temb):
         return output_tensor
 
 
-class Mish(torch.nn.Module):
-    def forward(self, hidden_states):
-        return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
-
-
 # unet_rl.py
 def rearrange_dims(tensor):
     if len(tensor.shape) == 2:

src/diffusers/models/unet_1d_blocks.py

Lines changed: 11 additions & 25 deletions

@@ -17,6 +17,7 @@
 import torch.nn.functional as F
 from torch import nn
 
+from .activations import get_activation
 from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims
 
 
@@ -55,14 +56,10 @@ def __init__(
 
         self.resnets = nn.ModuleList(resnets)
 
-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
-        elif non_linearity == "mish":
-            self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
-        else:
+        if non_linearity is None:
             self.nonlinearity = None
+        else:
+            self.nonlinearity = get_activation(non_linearity)
 
         self.downsample = None
         if add_downsample:
@@ -119,14 +116,10 @@ def __init__(
 
         self.resnets = nn.ModuleList(resnets)
 
-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
-        elif non_linearity == "mish":
-            self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
-        else:
+        if non_linearity is None:
             self.nonlinearity = None
+        else:
+            self.nonlinearity = get_activation(non_linearity)
 
         self.upsample = None
         if add_upsample:
@@ -194,14 +187,10 @@ def __init__(
 
         self.resnets = nn.ModuleList(resnets)
 
-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
-        elif non_linearity == "mish":
-            self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
-        else:
+        if non_linearity is None:
             self.nonlinearity = None
+        else:
+            self.nonlinearity = get_activation(non_linearity)
 
         self.upsample = None
         if add_upsample:
@@ -232,10 +221,7 @@ def __init__(self, num_groups_out, out_channels, embed_dim, act_fn):
         super().__init__()
         self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
         self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
-        if act_fn == "silu":
-            self.final_conv1d_act = nn.SiLU()
-        if act_fn == "mish":
-            self.final_conv1d_act = nn.Mish()
+        self.final_conv1d_act = get_activation(act_fn)
         self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
 
     def forward(self, hidden_states, temb=None):

src/diffusers/models/unet_2d_condition.py

Lines changed: 3 additions & 20 deletions

@@ -16,12 +16,12 @@
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 import torch.utils.checkpoint
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
 from ..utils import BaseOutput, logging
+from .activations import get_activation
 from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import (
     GaussianFourierProjection,
@@ -338,16 +338,8 @@ def __init__(
 
         if time_embedding_act_fn is None:
             self.time_embed_act = None
-        elif time_embedding_act_fn == "swish":
-            self.time_embed_act = lambda x: F.silu(x)
-        elif time_embedding_act_fn == "mish":
-            self.time_embed_act = nn.Mish()
-        elif time_embedding_act_fn == "silu":
-            self.time_embed_act = nn.SiLU()
-        elif time_embedding_act_fn == "gelu":
-            self.time_embed_act = nn.GELU()
         else:
-            raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
+            self.time_embed_act = get_activation(time_embedding_act_fn)
 
         self.down_blocks = nn.ModuleList([])
         self.up_blocks = nn.ModuleList([])
@@ -501,16 +493,7 @@ def __init__(
                 num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
             )
 
-            if act_fn == "swish":
-                self.conv_act = lambda x: F.silu(x)
-            elif act_fn == "mish":
-                self.conv_act = nn.Mish()
-            elif act_fn == "silu":
-                self.conv_act = nn.SiLU()
-            elif act_fn == "gelu":
-                self.conv_act = nn.GELU()
-            else:
-                raise ValueError(f"Unsupported activation function: {act_fn}")
+            self.conv_act = get_activation(act_fn)
 
         else:
             self.conv_norm_out = None

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 3 additions & 19 deletions

@@ -7,6 +7,7 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models import ModelMixin
+from ...models.activations import get_activation
 from ...models.attention import Attention
 from ...models.attention_processor import (
     AttentionProcessor,
@@ -441,16 +442,8 @@ def __init__(
 
         if time_embedding_act_fn is None:
             self.time_embed_act = None
-        elif time_embedding_act_fn == "swish":
-            self.time_embed_act = lambda x: F.silu(x)
-        elif time_embedding_act_fn == "mish":
-            self.time_embed_act = nn.Mish()
-        elif time_embedding_act_fn == "silu":
-            self.time_embed_act = nn.SiLU()
-        elif time_embedding_act_fn == "gelu":
-            self.time_embed_act = nn.GELU()
         else:
-            raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
+            self.time_embed_act = get_activation(time_embedding_act_fn)
 
         self.down_blocks = nn.ModuleList([])
         self.up_blocks = nn.ModuleList([])
@@ -604,16 +597,7 @@ def __init__(
                 num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
             )
 
-            if act_fn == "swish":
-                self.conv_act = lambda x: F.silu(x)
-            elif act_fn == "mish":
-                self.conv_act = nn.Mish()
-            elif act_fn == "silu":
-                self.conv_act = nn.SiLU()
-            elif act_fn == "gelu":
-                self.conv_act = nn.GELU()
-            else:
-                raise ValueError(f"Unsupported activation function: {act_fn}")
+            self.conv_act = get_activation(act_fn)
 
         else:
             self.conv_norm_out = None

tests/models/test_activations.py

Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
+import unittest
+
+import torch
+from torch import nn
+
+from diffusers.models.activations import get_activation
+
+
+class ActivationsTests(unittest.TestCase):
+    def test_swish(self):
+        act = get_activation("swish")
+
+        self.assertIsInstance(act, nn.SiLU)
+
+        self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0)
+        self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)
+
+    def test_silu(self):
+        act = get_activation("silu")
+
+        self.assertIsInstance(act, nn.SiLU)
+
+        self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0)
+        self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)
+
+    def test_mish(self):
+        act = get_activation("mish")
+
+        self.assertIsInstance(act, nn.Mish)
+
+        self.assertEqual(act(torch.tensor(-200, dtype=torch.float32)).item(), 0)
+        self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)
+
+    def test_gelu(self):
+        act = get_activation("gelu")
+
+        self.assertIsInstance(act, nn.GELU)
+
+        self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0)
+        self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)
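The tests above cover the four supported names; a small sketch (not part of the commit) of the error path that get_activation defines for any other name:

from diffusers.models.activations import get_activation

try:
    get_activation("relu")  # not among "swish"/"silu"/"mish"/"gelu" in this version of the helper
except ValueError as err:
    print(err)  # Unsupported activation function: relu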
