Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 8 additions & 19 deletions torchaudio/models/wav2vec2/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def __init__(

self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = torch.nn.Dropout(dropout)
self.dropout = dropout
self.head_dim = head_dim

self.scaling = self.head_dim**-0.5
Expand Down Expand Up @@ -304,25 +304,14 @@ def forward(

shape = (batch_size, length, self.num_heads, self.head_dim)
q = self.q_proj(x).view(*shape).transpose(2, 1) # B, nH, L, Hd
k = self.k_proj(x).view(*shape).permute(0, 2, 3, 1) # B, nH, Hd, L
k = self.k_proj(x).view(*shape).transpose(2, 1) # B, nH, L, Hd
v = self.v_proj(x).view(*shape).transpose(2, 1) # B, nH, L, Hd

# scale down q to avoid value overflow.
weights = (self.scaling * q) @ k # B, nH, L, L
if attention_mask is not None:
weights += attention_mask
# subtracting a constant value from the tensor won't change the output of softmax.
# apply the subtraction to avoid value overflow in torch.nn.functional.softmax.
# for more details, please see Equation 7 in https://arxiv.org/abs/2112.08778
weights = weights - weights.max(dim=-1, keepdim=True)[0]

weights = torch.nn.functional.softmax(weights, dim=-1)
weights = self.dropout(weights)

output = weights @ v # B, nH, L, Hd
output = output.transpose(2, 1).reshape(batch_size, length, embed_dim)

output = self.out_proj(output)
dropout = self.dropout if self.training else 0.0
attn_output = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attention_mask, dropout_p=dropout, is_causal=False
)
attn_output = attn_output.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
output = self.out_proj(attn_output)
return output, None # Necessary for compatibility with WavLMSelAttention


Expand Down