Skip to content

Commit 9feae5f

Browse files
authored
[Whisper] patch float type on mps (#35295)
* fix float type on mps * make
1 parent d5b81e1 commit 9feae5f

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

src/transformers/models/whisper/generation_whisper.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,9 @@ def generate(
632632
cur_bsz=cur_bsz,
633633
batch_idx_map=batch_idx_map,
634634
)
635-
time_offset = seek.to(torch.float64) * time_precision / input_stride
635+
time_offset = (
636+
seek.to(torch.float32 if device.type == "mps" else torch.float64) * time_precision / input_stride
637+
)
636638
seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)
637639

638640
# 6.2 cut out next 30s segment from input features
@@ -1805,6 +1807,7 @@ def _retrieve_segment(
18051807
timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
18061808
timestamp_segment_indices.add_(1)
18071809
token_timestamps = seek_outputs[idx]["token_timestamps"] if return_token_timestamps else []
1810+
device = seek_sequence.device
18081811

18091812
# If whisper predicted a "end of segment" via a timestep token, let's go ever each
18101813
# "end of segment" prediction and slice the decoding into segments accordingly
@@ -1828,8 +1831,12 @@ def _retrieve_segment(
18281831
end_timestamp_pos = sliced_tokens[idx_sliced_tokens] - timestamp_begin
18291832
segments.append(
18301833
{
1831-
"start": time_offset[prev_idx] + start_timestamp_pos.to(torch.float64) * time_precision,
1832-
"end": time_offset[prev_idx] + end_timestamp_pos.to(torch.float64) * time_precision,
1834+
"start": time_offset[prev_idx]
1835+
+ start_timestamp_pos.to(torch.float32 if device.type == "mps" else torch.float64)
1836+
* time_precision,
1837+
"end": time_offset[prev_idx]
1838+
+ end_timestamp_pos.to(torch.float32 if device.type == "mps" else torch.float64)
1839+
* time_precision,
18331840
"tokens": sliced_tokens,
18341841
"result": seek_outputs[idx],
18351842
}
@@ -1856,7 +1863,9 @@ def _retrieve_segment(
18561863
last_timestamp_pos = int(seek_num_frames[prev_idx] * time_precision_features / time_precision)
18571864
if timestamps.numel() > 0 and timestamps[-1] != timestamp_begin:
18581865
# no consecutive timestamps but it has a timestamp; use the last one.
1859-
last_timestamp_pos = (timestamps[-1] - timestamp_begin).to(torch.float64)
1866+
last_timestamp_pos = (timestamps[-1] - timestamp_begin).to(
1867+
torch.float32 if device.type == "mps" else torch.float64
1868+
)
18601869
segments = [
18611870
{
18621871
"start": time_offset[prev_idx],

0 commit comments

Comments
 (0)