@@ -640,9 +640,9 @@ def forward(
             reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool)
             attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-        attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32)
-        attn_weights = attn_weights_float.type_as(attn_weights)
-        attn_probs = F.dropout(attn_weights_float, p=self.dropout, training=self.training,)
+        attn_weights = F.softmax(attn_weights, dim=-1)
+        attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training,)
+
         assert v is not None
         attn_output = torch.bmm(attn_probs, v)
         assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
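Note on this hunk: the removed path computed the softmax in float32 and cast back with type_as, which is the numerically safer choice under fp16 training; the new path runs the softmax in the tensor's native dtype. A minimal standalone sketch of the two variants (illustration only, not the fairseq module itself; assumes a PyTorch build where half-precision softmax is supported on the current device):

import torch
import torch.nn.functional as F

attn_weights = torch.randn(2, 4, 4, dtype=torch.float16)

# Old path: upcast to float32 for the softmax, then cast back.
attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32)
old = attn_weights_float.type_as(attn_weights)

# New path: softmax directly in the native dtype (fp16 here).
new = F.softmax(attn_weights, dim=-1)

assert old.dtype == new.dtype == torch.float16

Both paths return the same dtype; they differ only in the precision used for the intermediate exponentiation and normalization, which is where fp16 can under/overflow.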
@@ -696,7 +696,7 @@ def _cat_prev_key_padding_mask(
         elif prev_key_padding_mask is not None:
             filler = torch.zeros(batch_size, src_len - prev_key_padding_mask.size(1))
             if prev_key_padding_mask.is_cuda:
-                filler = filler.cuda()
+                filler = filler.to(prev_key_padding_mask.device)
             new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
         elif key_padding_mask is not None:
             filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1))
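Note on this hunk: filler.cuda() always targets the default CUDA device, which breaks when the mask lives on another GPU; filler.to(prev_key_padding_mask.device) follows the mask wherever it is, and is a no-op on CPU. A minimal sketch of the pattern; make_filler is a hypothetical helper for illustration, not fairseq code:

import torch

def make_filler(mask: torch.Tensor, batch_size: int, pad_len: int) -> torch.Tensor:
    # Hypothetical helper: build a zero filler on the same device as `mask`.
    # .cuda() would pin the filler to cuda:0; .to(mask.device) also handles
    # CPU tensors and non-default GPU indices.
    filler = torch.zeros(batch_size, pad_len)
    return filler.to(mask.device)

mask = torch.ones(2, 3, dtype=torch.bool)  # may live on cpu, cuda:0, cuda:1, ...
assert make_filler(mask, 2, 4).device == mask.device

An equivalent, slightly tidier form is torch.zeros(batch_size, pad_len, device=mask.device), which allocates directly on the target device and avoids the intermediate CPU tensor.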