Skip to content

Commit 4c5ca93

Browse files
committed
vad : use uint64_t for time mapping
This commit changes the type of the `processed_time` and `original_time` fields in the `vad_time_mapping` struct from `double` to `uint64_t`. The motivation for this change is to improve precision, avoid floating-point inaccuracies, and be consistent with other parts of the code base that use `uint64_t` for time representation. This is part of a refactoring in which I'm also going to change the `vad_segment_info` struct to use `uint64_t` for the start and end times, which is the reason for the not-so-pleasant conversions and casts in the code at the moment.
1 parent e76cee9 commit 4c5ca93

File tree

1 file changed

+24
-26
lines changed

1 file changed

+24
-26
lines changed

src/whisper.cpp

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -869,8 +869,8 @@ struct whisper_aheads_masks {
869869
};
870870

871871
struct vad_time_mapping {
872-
double processed_time; // Time in processed (VAD) audio
873-
double original_time; // Corresponding time in original audio
872+
uint64_t processed_time; // Time in processed (VAD) audio
873+
uint64_t original_time; // Corresponding time in original audio
874874
};
875875

876876
struct whisper_state {
@@ -6716,8 +6716,8 @@ static bool whisper_vad(
67166716
segment.vad_end = (offset + segment_length) / (double)WHISPER_SAMPLE_RATE;
67176717

67186718
// Add segment boundaries to mapping table
6719-
vad_time_mapping start_mapping = {segment.vad_start, segment.orig_start};
6720-
vad_time_mapping end_mapping = {segment.vad_end, segment.orig_end};
6719+
vad_time_mapping start_mapping = {(uint64_t)(segment.vad_start * 100.0 + 0.5), (uint64_t)(segment.orig_start * 100.0 + 0.5)};
6720+
vad_time_mapping end_mapping = {(uint64_t)(segment.vad_end * 100.0 + 0.5), (uint64_t)(segment.orig_end * 100.0 + 0.5)};
67216721

67226722
state->vad_mapping_table.push_back(start_mapping);
67236723
state->vad_mapping_table.push_back(end_mapping);
@@ -6738,7 +6738,7 @@ static bool whisper_vad(
67386738
double proportion = (vad_time - segment.vad_start) / (segment.vad_end - segment.vad_start);
67396739
double orig_time = segment.orig_start + proportion * (segment.orig_end - segment.orig_start);
67406740

6741-
vad_time_mapping intermediate_mapping = {vad_time, orig_time};
6741+
vad_time_mapping intermediate_mapping = {(uint64_t)(vad_time * 100.0 + 0.5), (uint64_t)(orig_time * 100.0 + 0.5)};
67426742
state->vad_mapping_table.push_back(intermediate_mapping);
67436743
}
67446744
}
@@ -6762,8 +6762,8 @@ static bool whisper_vad(
67626762
double orig_silence_end = vad_segments->data[i+1].start;
67636763

67646764
// Add mapping points for silence boundaries
6765-
state->vad_mapping_table.push_back({silence_start_vad, orig_silence_start});
6766-
state->vad_mapping_table.push_back({silence_end_vad, orig_silence_end});
6765+
state->vad_mapping_table.push_back({(uint64_t)(silence_start_vad * 100.0 + 0.5),(uint64_t)(orig_silence_start * 100.0 + 0.5)});
6766+
state->vad_mapping_table.push_back({(uint64_t)(silence_end_vad * 100.0 + 0.5),(uint64_t)(orig_silence_end * 100.0 + 0.5)});
67676767

67686768
// Fill with zeros (silence)
67696769
memset(filtered_samples.data() + offset, 0, silence_samples * sizeof(float));
@@ -6783,7 +6783,7 @@ static bool whisper_vad(
67836783
if (!state->vad_mapping_table.empty()) {
67846784
auto last = std::unique(state->vad_mapping_table.begin(), state->vad_mapping_table.end(),
67856785
[](const vad_time_mapping& a, const vad_time_mapping& b) {
6786-
return std::abs(a.processed_time - b.processed_time) < 1e-9;
6786+
return a.processed_time == b.processed_time;
67876787
});
67886788
state->vad_mapping_table.erase(last, state->vad_mapping_table.end());
67896789
}
@@ -7873,7 +7873,7 @@ int whisper_full_lang_id(struct whisper_context * ctx) {
78737873
return ctx->state->lang_id;
78747874
}
78757875

7876-
static double map_processed_to_original_time(double processed_time, const std::vector<vad_time_mapping>& mapping_table) {
7876+
static uint64_t map_processed_to_original_time(uint64_t processed_time, const std::vector<vad_time_mapping>& mapping_table) {
78777877
if (mapping_table.empty()) {
78787878
return processed_time;
78797879
}
@@ -7889,29 +7889,29 @@ static double map_processed_to_original_time(double processed_time, const std::v
78897889
// Binary search over the time map that finds the first entry that has a
78907890
// processed time greater than or equal to the current processed time.
78917891
auto upper = std::lower_bound(mapping_table.begin(), mapping_table.end(), processed_time,
7892-
[](const vad_time_mapping& entry, double time) {
7892+
[](const vad_time_mapping& entry, uint64_t time) {
78937893
return entry.processed_time < time;
78947894
}
78957895
);
78967896

78977897
// If exact match found
7898-
if (std::abs(upper->processed_time - processed_time) < 1e-9) {
7898+
if (upper->processed_time == processed_time) {
78997899
return upper->original_time;
79007900
}
79017901

79027902
// Need to interpolate between two points
79037903
auto lower = upper - 1;
79047904

7905-
// Calculate the proportion
7906-
double proportion = 0.0;
7907-
double denominator = upper->processed_time - lower->processed_time;
7905+
uint64_t processed_diff = upper->processed_time - lower->processed_time;
7906+
uint64_t original_diff = upper->original_time - lower->original_time;
7907+
uint64_t offset = processed_time - lower->processed_time;
79087908

7909-
if (denominator > 1e-9) { // Avoid division by very small numbers
7910-
proportion = (processed_time - lower->processed_time) / denominator;
7909+
if (processed_diff == 0) {
7910+
return lower->original_time;
79117911
}
79127912

79137913
// Perform linear interpolation
7914-
return lower->original_time + proportion * (upper->original_time - lower->original_time);
7914+
return lower->original_time + (offset * original_diff) / processed_diff;
79157915
}
79167916

79177917
// Function to get the starting timestamp of a segment
@@ -7923,12 +7923,10 @@ int64_t whisper_full_get_segment_t0_from_state(struct whisper_state* state, int
79237923
}
79247924

79257925
// Get the processed timestamp
7926-
double t0 = state->result_all[i_segment].t0 / 100.0;
7926+
uint64_t t0 = state->result_all[i_segment].t0;
79277927

79287928
// Map to original time using the mapping table
7929-
double orig_t0 = map_processed_to_original_time(t0, state->vad_mapping_table);
7930-
7931-
return (int64_t)(orig_t0 * 100 + 0.5); // Round to nearest
7929+
return map_processed_to_original_time(t0, state->vad_mapping_table);
79327930
}
79337931

79347932
// Function to get the ending timestamp of a segment
@@ -7940,21 +7938,21 @@ int64_t whisper_full_get_segment_t1_from_state(struct whisper_state* state, int
79407938
}
79417939

79427940
// Get the processed timestamp
7943-
double t1 = state->result_all[i_segment].t1 / 100.0;
7941+
uint64_t t1 = state->result_all[i_segment].t1;
79447942

79457943
// Map to original time using the mapping table
7946-
double orig_t1 = map_processed_to_original_time(t1, state->vad_mapping_table);
7944+
uint64_t orig_t1 = map_processed_to_original_time(t1, state->vad_mapping_table);
79477945

79487946
// Get the corresponding t0 for this segment
7949-
double orig_t0 = whisper_full_get_segment_t0_from_state(state, i_segment) / 100.0;
7947+
uint64_t orig_t0 = whisper_full_get_segment_t0_from_state(state, i_segment);
79507948

79517949
// Ensure minimum duration to prevent zero-length segments
7952-
const double min_duration = 0.01; // 10ms minimum
7950+
const uint64_t min_duration = 10; // 10ms minimum
79537951
if (orig_t1 - orig_t0 < min_duration) {
79547952
orig_t1 = orig_t0 + min_duration;
79557953
}
79567954

7957-
return (int64_t)(orig_t1 * 100 + 0.5); // Round to nearest
7955+
return orig_t1;
79587956
}
79597957

79607958

0 commit comments

Comments
 (0)