Skip to content

Commit 21d0442

Browse files
jiqing-feng and BernardZach
authored and committed
enable QA bf16 pipeline (huggingface#34483)
* enable QA bf16 pipeline
* add tests
1 parent c0c9206 commit 21d0442

File tree

2 files changed

+41
-2
lines changed

2 files changed

+41
-2
lines changed

src/transformers/pipelines/question_answering.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -540,8 +540,14 @@ def postprocess(
540540
min_null_score = 1000000 # large and positive
541541
answers = []
542542
for output in model_outputs:
543-
start_ = output["start"]
544-
end_ = output["end"]
543+
if self.framework == "pt" and output["start"].dtype == torch.bfloat16:
544+
start_ = output["start"].to(torch.float32)
545+
else:
546+
start_ = output["start"]
547+
if self.framework == "pt" and output["end"].dtype == torch.bfloat16:
548+
end_ = output["end"].to(torch.float32)
549+
else:
550+
end_ = output["end"]
545551
example = output["example"]
546552
p_mask = output["p_mask"]
547553
attention_mask = (

tests/pipelines/test_pipelines_question_answering.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,18 @@
2727
from transformers.testing_utils import (
2828
compare_pipeline_output_to_hub_spec,
2929
is_pipeline_test,
30+
is_torch_available,
3031
nested_simplify,
3132
require_tf,
3233
require_torch,
3334
require_torch_or_tf,
3435
slow,
3536
)
3637

38+
39+
if is_torch_available():
40+
import torch
41+
3742
from .test_pipelines_common import ANY
3843

3944

@@ -165,6 +170,34 @@ def test_small_model_pt(self):
165170

166171
self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
167172

173+
@require_torch
174+
def test_small_model_pt_fp16(self):
175+
question_answerer = pipeline(
176+
"question-answering",
177+
model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
178+
torch_dtype=torch.float16,
179+
)
180+
181+
outputs = question_answerer(
182+
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
183+
)
184+
185+
self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
186+
187+
@require_torch
188+
def test_small_model_pt_bf16(self):
189+
question_answerer = pipeline(
190+
"question-answering",
191+
model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
192+
torch_dtype=torch.bfloat16,
193+
)
194+
195+
outputs = question_answerer(
196+
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
197+
)
198+
199+
self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
200+
168201
@require_torch
169202
def test_small_model_pt_iterator(self):
170203
# https://github.com/huggingface/transformers/issues/18510

0 commit comments

Comments
 (0)