@@ -632,7 +632,9 @@ def generate(
632
632
cur_bsz = cur_bsz ,
633
633
batch_idx_map = batch_idx_map ,
634
634
)
635
- time_offset = seek .to (torch .float64 ) * time_precision / input_stride
635
+ time_offset = (
636
+ seek .to (torch .float32 if device .type == "mps" else torch .float64 ) * time_precision / input_stride
637
+ )
636
638
seek_num_frames = (max_frames - seek ).clamp (max = num_segment_frames )
637
639
638
640
# 6.2 cut out next 30s segment from input features
@@ -1805,6 +1807,7 @@ def _retrieve_segment(
1805
1807
timestamp_segment_indices = torch .where (timestamp_tokens [:- 1 ] & timestamp_tokens [1 :])[0 ]
1806
1808
timestamp_segment_indices .add_ (1 )
1807
1809
token_timestamps = seek_outputs [idx ]["token_timestamps" ] if return_token_timestamps else []
1810
+ device = seek_sequence .device
1808
1811
1809
1812
# If whisper predicted a "end of segment" via a timestep token, let's go ever each
1810
1813
# "end of segment" prediction and slice the decoding into segments accordingly
@@ -1828,8 +1831,12 @@ def _retrieve_segment(
1828
1831
end_timestamp_pos = sliced_tokens [idx_sliced_tokens ] - timestamp_begin
1829
1832
segments .append (
1830
1833
{
1831
- "start" : time_offset [prev_idx ] + start_timestamp_pos .to (torch .float64 ) * time_precision ,
1832
- "end" : time_offset [prev_idx ] + end_timestamp_pos .to (torch .float64 ) * time_precision ,
1834
+ "start" : time_offset [prev_idx ]
1835
+ + start_timestamp_pos .to (torch .float32 if device .type == "mps" else torch .float64 )
1836
+ * time_precision ,
1837
+ "end" : time_offset [prev_idx ]
1838
+ + end_timestamp_pos .to (torch .float32 if device .type == "mps" else torch .float64 )
1839
+ * time_precision ,
1833
1840
"tokens" : sliced_tokens ,
1834
1841
"result" : seek_outputs [idx ],
1835
1842
}
@@ -1856,7 +1863,9 @@ def _retrieve_segment(
1856
1863
last_timestamp_pos = int (seek_num_frames [prev_idx ] * time_precision_features / time_precision )
1857
1864
if timestamps .numel () > 0 and timestamps [- 1 ] != timestamp_begin :
1858
1865
# no consecutive timestamps but it has a timestamp; use the last one.
1859
- last_timestamp_pos = (timestamps [- 1 ] - timestamp_begin ).to (torch .float64 )
1866
+ last_timestamp_pos = (timestamps [- 1 ] - timestamp_begin ).to (
1867
+ torch .float32 if device .type == "mps" else torch .float64
1868
+ )
1860
1869
segments = [
1861
1870
{
1862
1871
"start" : time_offset [prev_idx ],
0 commit comments