Skip to content

Commit be9d632

Browse files
committed
tests
1 parent 57be389 commit be9d632

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

tests/models/test_activations.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import unittest
2+
3+
import torch
4+
from torch import nn
5+
6+
from diffusers.models.activations import get_activation
7+
8+
9+
class ActivationsTests(unittest.TestCase):
10+
def test_swish(self):
11+
act = get_activation("swish")
12+
13+
self.assertIsInstance(act, nn.SiLU)
14+
15+
self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0)
16+
self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
17+
self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
18+
self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)
19+
20+
def test_silu(self):
21+
act = get_activation("silu")
22+
23+
self.assertIsInstance(act, nn.SiLU)
24+
25+
self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0)
26+
self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
27+
self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
28+
self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)
29+
30+
def test_mish(self):
31+
act = get_activation("mish")
32+
33+
self.assertIsInstance(act, nn.Mish)
34+
35+
self.assertEqual(act(torch.tensor(-200, dtype=torch.float32)).item(), 0)
36+
self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
37+
self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
38+
self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)
39+
40+
def test_gelu(self):
41+
act = get_activation("gelu")
42+
43+
self.assertIsInstance(act, nn.GELU)
44+
45+
self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0)
46+
self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
47+
self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
48+
self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)

0 commit comments

Comments
 (0)