Commit 4cd46de

fix transformerblock

Signed-off-by: yang-ze-kang <[email protected]>
1 parent 53382d8

File tree

1 file changed (+10 -9 lines)

monai/networks/blocks/transformerblock.py

Lines changed: 10 additions & 9 deletions
@@ -80,15 +80,16 @@ def __init__(
         self.norm2 = nn.LayerNorm(hidden_size)
         self.with_cross_attention = with_cross_attention

-        self.norm_cross_attn = nn.LayerNorm(hidden_size)
-        self.cross_attn = CrossAttentionBlock(
-            hidden_size=hidden_size,
-            num_heads=num_heads,
-            dropout_rate=dropout_rate,
-            qkv_bias=qkv_bias,
-            causal=False,
-            use_flash_attention=use_flash_attention,
-        )
+        if with_cross_attention:
+            self.norm_cross_attn = nn.LayerNorm(hidden_size)
+            self.cross_attn = CrossAttentionBlock(
+                hidden_size=hidden_size,
+                num_heads=num_heads,
+                dropout_rate=dropout_rate,
+                qkv_bias=qkv_bias,
+                causal=False,
+                use_flash_attention=use_flash_attention,
+            )

     def forward(
         self, x: torch.Tensor, context: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None
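The change makes the cross-attention submodules conditional on with_cross_attention, so a plain self-attention block no longer allocates an unused norm_cross_attn / cross_attn pair. A minimal usage sketch follows; it is not part of this commit, and the mlp_dim argument, import path, and tensor shapes are assumptions based on MONAI's usual TransformerBlock constructor rather than taken from the diff.

import torch

from monai.networks.blocks.transformerblock import TransformerBlock

x = torch.randn(2, 16, 256)  # (batch, sequence, hidden_size); shapes are illustrative

# Self-attention only: after this fix, no cross-attention parameters are created.
blk = TransformerBlock(hidden_size=256, mlp_dim=1024, num_heads=8, dropout_rate=0.1)
out = blk(x)  # -> (2, 16, 256)

# With cross-attention enabled, norm_cross_attn and cross_attn are built, and a
# context tensor can be passed to forward(), matching the signature in the diff.
blk_ca = TransformerBlock(
    hidden_size=256,
    mlp_dim=1024,
    num_heads=8,
    dropout_rate=0.1,
    with_cross_attention=True,
)
context = torch.randn(2, 10, 256)
out_ca = blk_ca(x, context=context)  # -> (2, 16, 256)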
