
Commit 1741d74

Merge pull request #3145 from sshleifer/bartfp16
[Bart] FP16 Support
2 parents: bbabbc1 + 14d4058

File tree

2 files changed: +11, -4 lines

src/transformers/modeling_bart.py

Lines changed: 4 additions & 4 deletions
@@ -640,9 +640,9 @@ def forward(
             reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool)
             attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-        attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32)
-        attn_weights = attn_weights_float.type_as(attn_weights)
-        attn_probs = F.dropout(attn_weights_float, p=self.dropout, training=self.training,)
+        attn_weights = F.softmax(attn_weights, dim=-1)
+        attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training,)
+
         assert v is not None
         attn_output = torch.bmm(attn_probs, v)
         assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
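
Why this hunk fixes fp16: F.softmax(..., dtype=torch.float32) always returns a float32 tensor, so under half precision attn_probs came out float32 while v stayed float16, and the subsequent torch.bmm raised a dtype mismatch. Letting softmax run in the input's own dtype keeps the whole attention path in fp16. A minimal standalone sketch of the dtype behavior (not code from the patch; run on CUDA if your PyTorch build lacks CPU half-precision softmax):

import torch
import torch.nn.functional as F

attn_weights = torch.randn(2, 4, 4, dtype=torch.float16)

# Old path: forcing float32 silently upcasts the probabilities, so the
# later torch.bmm(attn_probs, v) would mix float32 with a float16 `v`.
probs_old = F.softmax(attn_weights, dim=-1, dtype=torch.float32)
print(probs_old.dtype)  # torch.float32

# New path: softmax stays in the input dtype, so attention runs end to
# end in half precision.
probs_new = F.softmax(attn_weights, dim=-1)
print(probs_new.dtype)  # torch.float16
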
@@ -696,7 +696,7 @@ def _cat_prev_key_padding_mask(
         elif prev_key_padding_mask is not None:
             filler = torch.zeros(batch_size, src_len - prev_key_padding_mask.size(1))
             if prev_key_padding_mask.is_cuda:
-                filler = filler.cuda()
+                filler = filler.to(prev_key_padding_mask.device)
             new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
         elif key_padding_mask is not None:
             filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1))
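
The second hunk is a device-correctness fix: .cuda() moves the filler to the current default CUDA device, which need not be the device holding prev_key_padding_mask in a multi-GPU setup, whereas .to(prev_key_padding_mask.device) always follows the mask. A small sketch of the pattern, using a hypothetical make_filler helper (not part of the patch); allocating directly with the device= keyword also avoids the extra copy:

import torch

def make_filler(mask: torch.Tensor, pad_len: int) -> torch.Tensor:
    # Hypothetical helper: derive placement from the tensor the filler will
    # be concatenated with, rather than hard-coding .cuda().
    return torch.zeros(mask.size(0), pad_len, device=mask.device)

mask = torch.ones(2, 3)
padded = torch.cat([mask, make_filler(mask, 2)], dim=1)
print(padded.shape)  # torch.Size([2, 5])
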

tests/test_modeling_bart.py

Lines changed: 7 additions & 0 deletions
@@ -294,6 +294,13 @@ def test_tokenization(self):
             bart_toks = tokenizer.encode(ex, return_tensors="pt")
             _assert_tensors_equal(desired_result.long(), bart_toks, prefix=ex)
 
+    @unittest.skipIf(torch_device == "cpu", "Cant do half precision")
+    def test_generate_fp16(self):
+        config, input_ids, batch_size = self._get_config_and_data(output_past=True)
+        attention_mask = input_ids.ne(1)
+        lm_model = BartForMaskedLM(config).eval().to(torch_device).half()
+        lm_model.generate(input_ids, attention_mask)
+
 
 def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
     """If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
