
Commit 0a68a1f

add test for AttentionBlock, SpatialTransformer

1 parent 5164c9f

File tree

2 files changed: +44 −0 lines

  src/diffusers/models/__init__.py
  tests/test_layers_utils.py

src/diffusers/models/__init__.py

1 addition & 0 deletions

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .attention import AttentionBlock, SpatialTransformer
 from .unet_2d import UNet2DModel
 from .unet_2d_condition import UNet2DConditionModel
 from .vae import AutoencoderKL, VQModel
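
The only functional change above is a re-export: after it, both blocks are importable from diffusers.models directly rather than from diffusers.models.attention. A minimal usage sketch (not part of the commit; the constructor arguments are copied from the test below, everything else is plain PyTorch):

import torch
from diffusers.models import AttentionBlock  # package-level export added by this commit

attn = AttentionBlock(channels=32, num_head_channels=1, rescale_output_factor=1.0, eps=1e-6, num_groups=32)
with torch.no_grad():
    out = attn(torch.randn(1, 32, 64, 64))  # self-attention over the 64x64 spatial positions
assert out.shape == (1, 32, 64, 64)  # the block preserves the input shape, as the test asserts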

tests/test_layers_utils.py

43 additions & 0 deletions

@@ -19,6 +19,7 @@
 import numpy as np
 import torch
 
+from diffusers.models.attention import AttentionBlock, SpatialTransformer
 from diffusers.models.embeddings import get_timestep_embedding
 from diffusers.models.resnet import Downsample2D, Upsample2D
 
@@ -216,3 +217,45 @@ def test_downsample_with_conv_out_dim(self):
         output_slice = downsampled[0, -1, -3:, -3:]
         expected_slice = torch.tensor([-0.6586, 0.5985, 0.0721, 0.1256, -0.1492, 0.4436, -0.2544, 0.5021, 1.1522])
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+
+class AttentionBlockTests(unittest.TestCase):
+    def test_attention_block_default(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 64, 64)
+        attentionBlock = AttentionBlock(
+            channels=32,
+            num_head_channels=1,
+            rescale_output_factor=1.0,
+            eps=1e-6,
+            num_groups=32,
+        )
+        with torch.no_grad():
+            attention_scores = attentionBlock(sample)
+
+        assert attention_scores.shape == (1, 32, 64, 64)
+        output_slice = attention_scores[0, -1, -3:, -3:]
+
+        expected_slice = torch.tensor([-1.4975, -0.0038, -0.7847, -1.4567, 1.1220, -0.8962, -1.7394, 1.1319, -0.5427])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-1)
+
+
+class SpatialTransformerTests(unittest.TestCase):
+    def test_spatial_transformer_default(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 64, 64)
+        spatialTransformerBlock = SpatialTransformer(
+            in_channels=32,
+            n_heads=1,
+            d_head=32,
+            dropout=0.0,
+            context_dim=None,
+        )
+        with torch.no_grad():
+            attention_scores = spatialTransformerBlock(sample)
+
+        assert attention_scores.shape == (1, 32, 64, 64)
+        output_slice = attention_scores[0, -1, -3:, -3:]
+
+        expected_slice = torch.tensor([-1.2447, -0.0137, -0.9559, -1.5223, 0.6991, -1.0126, -2.0974, 0.8921, -1.0201])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-1)
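
To run only the new tests locally, something like the following should work, assuming the repository's usual pytest setup (the classes are plain unittest.TestCase subclasses, so standard unittest discovery works as well):

python -m pytest tests/test_layers_utils.py -k "AttentionBlockTests or SpatialTransformerTests"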
