-
Notifications
You must be signed in to change notification settings - Fork 6.1k
move activation dispatches into helper function #3656
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
williamberman
merged 2 commits into
huggingface:main
from
williamberman:activations_refactor
Jun 5, 2023
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from torch import nn | ||
|
||
|
||
def get_activation(act_fn): | ||
if act_fn in ["swish", "silu"]: | ||
return nn.SiLU() | ||
elif act_fn == "mish": | ||
return nn.Mish() | ||
elif act_fn == "gelu": | ||
return nn.GELU() | ||
else: | ||
raise ValueError(f"Unsupported activation function: {act_fn}") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ | |
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
from .activations import get_activation | ||
from .attention import AdaGroupNorm | ||
from .attention_processor import SpatialNorm | ||
|
||
|
@@ -558,14 +559,7 @@ def __init__( | |
conv_2d_out_channels = conv_2d_out_channels or out_channels | ||
self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) | ||
|
||
if non_linearity == "swish": | ||
self.nonlinearity = lambda x: F.silu(x) | ||
elif non_linearity == "mish": | ||
self.nonlinearity = nn.Mish() | ||
elif non_linearity == "silu": | ||
self.nonlinearity = nn.SiLU() | ||
elif non_linearity == "gelu": | ||
self.nonlinearity = nn.GELU() | ||
self.nonlinearity = get_activation(non_linearity) | ||
|
||
self.upsample = self.downsample = None | ||
if self.up: | ||
|
@@ -646,11 +640,6 @@ def forward(self, input_tensor, temb): | |
return output_tensor | ||
|
||
|
||
class Mish(torch.nn.Module): | ||
def forward(self, hidden_states): | ||
return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) | ||
|
||
Comment on lines
-649
to
-652
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not used |
||
|
||
# unet_rl.py | ||
def rearrange_dims(tensor): | ||
if len(tensor.shape) == 2: | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import unittest | ||
|
||
import torch | ||
from torch import nn | ||
|
||
from diffusers.models.activations import get_activation | ||
|
||
|
||
class ActivationsTests(unittest.TestCase): | ||
def test_swish(self): | ||
act = get_activation("swish") | ||
|
||
self.assertIsInstance(act, nn.SiLU) | ||
|
||
self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0) | ||
self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0) | ||
self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0) | ||
self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20) | ||
|
||
def test_silu(self): | ||
act = get_activation("silu") | ||
|
||
self.assertIsInstance(act, nn.SiLU) | ||
|
||
self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0) | ||
self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0) | ||
self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0) | ||
self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20) | ||
|
||
def test_mish(self): | ||
act = get_activation("mish") | ||
|
||
self.assertIsInstance(act, nn.Mish) | ||
|
||
self.assertEqual(act(torch.tensor(-200, dtype=torch.float32)).item(), 0) | ||
self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0) | ||
self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0) | ||
self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20) | ||
|
||
def test_gelu(self): | ||
act = get_activation("gelu") | ||
|
||
self.assertIsInstance(act, nn.GELU) | ||
|
||
self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0) | ||
self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0) | ||
self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0) | ||
self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.