Skip to content

Commit 82d6980

Browse files
committed
vad : revisit timestamp alignment/mapping
This commit improves the timestamp alignment by introducing a mapping table, adding intermediate reference points for longer segments, and using binary search for lookups. The motivation for this change is to address issues with the current solution, where zero-length segments are possible, and also to improve the precision of the VAD timestamps. Refs: ggml-org#3162
1 parent 2c4b904 commit 82d6980

File tree

1 file changed

+147
-110
lines changed

1 file changed

+147
-110
lines changed

src/whisper.cpp

Lines changed: 147 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -868,6 +868,11 @@ struct whisper_aheads_masks {
868868
ggml_backend_buffer_t buffer = nullptr;
869869
};
870870

871+
// One reference point relating a position in the VAD-processed audio to the
// corresponding position in the original audio. Field order matters: the
// struct is brace-initialized as {processed_time, original_time}.
struct vad_time_mapping {
    // Time in the processed (VAD-filtered) audio.
    double processed_time;
    // Matching time in the original, unfiltered audio.
    double original_time;
};
875+
871876
struct whisper_state {
872877
int64_t t_sample_us = 0;
873878
int64_t t_encode_us = 0;
@@ -957,13 +962,16 @@ struct whisper_state {
957962
whisper_vad_context * vad_context = nullptr;
958963

959964
// Describes one detected speech segment in both time bases: where it lies in
// the original audio and where it ended up in the VAD-filtered audio.
struct vad_segment_info {
    // Segment start in the original audio.
    double orig_start;
    // Segment end in the original audio.
    double orig_end;
    // Segment start in the filtered (VAD) audio.
    double vad_start;
    // Segment end in the filtered (VAD) audio.
    double vad_end;
};
965970
std::vector<vad_segment_info> vad_segments;
966971
bool has_vad_segments = false;
972+
973+
std::vector<vad_time_mapping> vad_mapping_table;
974+
bool vad_mapping_table_initialized = false;
967975
};
968976

969977
struct whisper_context {
@@ -4420,8 +4428,8 @@ struct whisper_vad_model {
44204428
};
44214429

44224430
// A single speech segment reported by the VAD, expressed in seconds.
struct whisper_vad_segment {
    // Start time in seconds
    double start;
    // End time in seconds
    double end;
};
44264434

44274435
struct whisper_vad_segments {
@@ -6617,9 +6625,13 @@ static bool whisper_vad(
66176625
int n_samples,
66186626
std::vector<float> & filtered_samples,
66196627
int & filtered_n_samples) {
6620-
WHISPER_LOG_INFO("%s: VAD is enabled, processing speach segments only\n", __func__);
6628+
WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
66216629
filtered_n_samples = 0;
66226630

6631+
// Clear any existing mapping table
6632+
state->vad_mapping_table.clear();
6633+
state->vad_mapping_table_initialized = false;
6634+
66236635
if (state->vad_context == nullptr) {
66246636
struct whisper_vad_context_params vad_ctx_params = whisper_vad_default_context_params();
66256637
struct whisper_vad_context * vctx = whisper_vad_init_from_file_with_params(params.vad_model_path, vad_ctx_params);
@@ -6640,6 +6652,11 @@ static bool whisper_vad(
66406652
ctx->state->vad_segments.clear();
66416653
ctx->state->vad_segments.reserve(vad_segments->data.size());
66426654

6655+
// Initialize the time mapping table
6656+
state->vad_mapping_table.clear();
6657+
state->vad_mapping_table.reserve(vad_segments->data.size() * 4);
6658+
state->vad_mapping_table_initialized = true;
6659+
66436660
WHISPER_LOG_INFO("%s: detected %d speech segments\n", __func__, (int)vad_segments->data.size());
66446661
float overlap_seconds = vad_params.samples_overlap;
66456662
int overlap_samples = overlap_seconds * WHISPER_SAMPLE_RATE;
@@ -6689,15 +6706,42 @@ static bool whisper_vad(
66896706
segment_start_samples = std::min(segment_start_samples, n_samples - 1);
66906707
segment_end_samples = std::min(segment_end_samples, n_samples);
66916708
int segment_length = segment_end_samples - segment_start_samples;
6692-
66936709
if (segment_length > 0) {
66946710
whisper_state::vad_segment_info segment;
66956711

66966712
segment.orig_start = vad_segments->data[i].start;
66976713
segment.orig_end = vad_segments->data[i].end;
66986714

6699-
segment.vad_start = offset / (float)WHISPER_SAMPLE_RATE;
6700-
segment.vad_end = (offset + segment_length) / (float)WHISPER_SAMPLE_RATE;
6715+
segment.vad_start = offset / (double)WHISPER_SAMPLE_RATE;
6716+
segment.vad_end = (offset + segment_length) / (double)WHISPER_SAMPLE_RATE;
6717+
6718+
// Add segment boundaries to mapping table
6719+
vad_time_mapping start_mapping = {segment.vad_start, segment.orig_start};
6720+
vad_time_mapping end_mapping = {segment.vad_end, segment.orig_end};
6721+
6722+
state->vad_mapping_table.push_back(start_mapping);
6723+
state->vad_mapping_table.push_back(end_mapping);
6724+
6725+
// Add intermediate points for longer segments to improve interpolation accuracy
6726+
const double min_segment_length = 1.0; // 1 second
6727+
const double point_interval = 0.2; // Add a point every 200ms
6728+
6729+
if (segment.vad_end - segment.vad_start > min_segment_length) {
6730+
double segment_duration = segment.vad_end - segment.vad_start;
6731+
int num_points = (int)(segment_duration / point_interval) - 1;
6732+
6733+
for (int j = 1; j <= num_points; j++) {
6734+
double vad_time = segment.vad_start + j * point_interval;
6735+
6736+
if (vad_time >= segment.vad_end) continue;
6737+
6738+
double proportion = (vad_time - segment.vad_start) / (segment.vad_end - segment.vad_start);
6739+
double orig_time = segment.orig_start + proportion * (segment.orig_end - segment.orig_start);
6740+
6741+
vad_time_mapping intermediate_mapping = {vad_time, orig_time};
6742+
state->vad_mapping_table.push_back(intermediate_mapping);
6743+
}
6744+
}
67016745

67026746
WHISPER_LOG_INFO("%s: vad_segment_info: orig_start: %.2f, orig_end: %.2f, vad_start: %.2f, vad_end: %.2f\n",
67036747
__func__, segment.orig_start, segment.orig_end, segment.vad_start, segment.vad_end);
@@ -6709,13 +6753,43 @@ static bool whisper_vad(
67096753

67106754
// Add silence after this segment (except after the last segment)
67116755
if (i < (int)vad_segments->data.size() - 1) {
6756+
// Calculate the start and end time of the silence gap in processed audio
6757+
double silence_start_vad = offset / (double)WHISPER_SAMPLE_RATE;
6758+
double silence_end_vad = (offset + silence_samples) / (double)WHISPER_SAMPLE_RATE;
6759+
6760+
// Calculate the corresponding original times
6761+
double orig_silence_start = segment.orig_end;
6762+
double orig_silence_end = vad_segments->data[i+1].start;
6763+
6764+
// Add mapping points for silence boundaries
6765+
state->vad_mapping_table.push_back({silence_start_vad, orig_silence_start});
6766+
state->vad_mapping_table.push_back({silence_end_vad, orig_silence_end});
6767+
67126768
// Fill with zeros (silence)
67136769
memset(filtered_samples.data() + offset, 0, silence_samples * sizeof(float));
67146770
offset += silence_samples;
67156771
}
67166772
}
67176773
}
67186774

6775+
// Sort the mapping table by processed time
6776+
std::sort(state->vad_mapping_table.begin(), state->vad_mapping_table.end(),
6777+
[](const vad_time_mapping& a, const vad_time_mapping& b) {
6778+
return a.processed_time < b.processed_time;
6779+
});
6780+
6781+
// Remove any duplicate processed times to ensure monotonicity which is
6782+
// needed for binary search and interpolation later.
6783+
if (!state->vad_mapping_table.empty()) {
6784+
auto last = std::unique(state->vad_mapping_table.begin(), state->vad_mapping_table.end(),
6785+
[](const vad_time_mapping& a, const vad_time_mapping& b) {
6786+
return std::abs(a.processed_time - b.processed_time) < 1e-9;
6787+
});
6788+
state->vad_mapping_table.erase(last, state->vad_mapping_table.end());
6789+
}
6790+
6791+
WHISPER_LOG_INFO("%s: Created time mapping table with %d points\n", __func__, (int)state->vad_mapping_table.size());
6792+
67196793
filtered_n_samples = offset;
67206794
WHISPER_LOG_INFO("%s: Reduced audio from %d to %d samples (%.1f%% reduction)\n",
67216795
__func__, n_samples, filtered_n_samples, 100.0f * (1.0f - (float)filtered_n_samples / n_samples));
@@ -7799,130 +7873,93 @@ int whisper_full_lang_id(struct whisper_context * ctx) {
77997873
return ctx->state->lang_id;
78007874
}
78017875

7802-
int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) {
7803-
// If VAD wasn't used, return the original timestamp
7804-
if (!state->has_vad_segments || state->vad_segments.empty()) {
7805-
return state->result_all[i_segment].t0;
7876+
static double map_processed_to_original_time(double processed_time, const std::vector<vad_time_mapping>& mapping_table) {
7877+
if (mapping_table.empty()) {
7878+
return processed_time;
78067879
}
78077880

7808-
// Get the start timestamp produced by whisper_full. whisper_full processes
7809-
// only the speech segments in this case so we need to map these timestamps
7810-
// back to the original audio.
7811-
float t0 = state->result_all[i_segment].t0 / 100.0f;
7881+
if (processed_time <= mapping_table.front().processed_time) {
7882+
return mapping_table.front().original_time; // Before first mapping point
7883+
}
78127884

7813-
// Find which VAD segment this timestamp belongs.
7814-
// TODO(danbev) This could be optimized by using a binary search if the number
7815-
// of segments exceed a certain limit. Also we might be able to assume that
7816-
// the access pattern is sequential and optimized for that too.
7817-
for (size_t i = 0; i < state->vad_segments.size(); i++) {
7818-
const auto & segment = state->vad_segments[i];
7885+
if (processed_time >= mapping_table.back().processed_time) {
7886+
return mapping_table.back().original_time; // After last mapping point
7887+
}
78197888

7820-
// Check if the timestamp falls within this segment.
7821-
if (t0 >= segment.vad_start && t0 <= segment.vad_end) {
7822-
float proportion = 0.0f;
7823-
if (segment.vad_end > segment.vad_start) {
7824-
proportion = (t0 - segment.vad_start) / (segment.vad_end - segment.vad_start);
7825-
}
7826-
float orig_t0 = segment.orig_start + proportion * (segment.orig_end - segment.orig_start);
7827-
return (int64_t)(orig_t0 * 100);
7889+
// Binary search over the time map that finds the first entry that has a
7890+
// processed time greater than or equal to the current processed time.
7891+
auto upper = std::lower_bound(mapping_table.begin(), mapping_table.end(), processed_time,
7892+
[](const vad_time_mapping& entry, double time) {
7893+
return entry.processed_time < time;
78287894
}
7895+
);
7896+
7897+
// If exact match found
7898+
if (std::abs(upper->processed_time - processed_time) < 1e-9) {
7899+
return upper->original_time;
78297900
}
78307901

7831-
// Check if the timestamp falls between two segments.
7832-
for (size_t i = 0; i < state->vad_segments.size() - 1; i++) {
7833-
const auto & curr = state->vad_segments[i];
7834-
const auto & next = state->vad_segments[i + 1];
7902+
// Need to interpolate between two points
7903+
auto lower = upper - 1;
78357904

7836-
if (t0 > curr.vad_end && t0 < next.vad_start) {
7837-
// Calculate how far we are through the gap as a proportion
7838-
float gap_proportion = 0.0f;
7839-
if (next.vad_start > curr.vad_end) {
7840-
gap_proportion = (t0 - curr.vad_end) / (next.vad_start - curr.vad_end);
7841-
}
7842-
float orig_t0 = curr.orig_end + gap_proportion * (next.orig_start - curr.orig_end);
7843-
return (int64_t)(orig_t0 * 100);
7844-
}
7845-
}
7905+
// Calculate the proportion
7906+
double proportion = 0.0;
7907+
double denominator = upper->processed_time - lower->processed_time;
78467908

7847-
// Handle the case where the timestamp is after the last segment.
7848-
if (t0 > state->vad_segments.back().vad_end) {
7849-
// For timestamps after the last segment, add the extra time to the end of the last segment
7850-
const auto& last = state->vad_segments.back();
7851-
// Calculate how far beyond the last segment
7852-
float extra_time = t0 - last.vad_end;
7853-
// Add this extra time to the original end time
7854-
float orig_t0 = last.orig_end + extra_time;
7855-
return (int64_t)(orig_t0 * 100);
7909+
if (denominator > 1e-9) { // Avoid division by very small numbers
7910+
proportion = (processed_time - lower->processed_time) / denominator;
78567911
}
78577912

7858-
WHISPER_LOG_WARN("%s: Could not map t0 = %f to a VAD segment\n", __func__, t0);
7859-
return t0;
7913+
// Perform linear interpolation
7914+
return lower->original_time + proportion * (upper->original_time - lower->original_time);
78607915
}
78617916

7862-
int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
7863-
return whisper_full_get_segment_t0_from_state(ctx->state, i_segment);
7917+
// Function to get the starting timestamp of a segment
7918+
int64_t whisper_full_get_segment_t0_from_state(struct whisper_state* state, int i_segment) {
7919+
// If VAD wasn't used, return the original timestamp
7920+
if (!state->has_vad_segments || !state->vad_mapping_table_initialized ||
7921+
state->vad_mapping_table.empty()) {
7922+
return state->result_all[i_segment].t0;
7923+
}
7924+
7925+
// Get the processed timestamp
7926+
double t0 = state->result_all[i_segment].t0 / 100.0;
7927+
7928+
// Map to original time using the mapping table
7929+
double orig_t0 = map_processed_to_original_time(t0, state->vad_mapping_table);
7930+
7931+
return (int64_t)(orig_t0 * 100 + 0.5); // Round to nearest
78647932
}
78657933

7866-
int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment) {
7934+
// Function to get the ending timestamp of a segment
7935+
int64_t whisper_full_get_segment_t1_from_state(struct whisper_state* state, int i_segment) {
78677936
// If VAD wasn't used, return the original timestamp
7868-
if (!state->has_vad_segments || state->vad_segments.empty()) {
7937+
if (!state->has_vad_segments || !state->vad_mapping_table_initialized ||
7938+
state->vad_mapping_table.empty()) {
78697939
return state->result_all[i_segment].t1;
78707940
}
78717941

7872-
// Get the end timestamp produced by whisper_full. whisper_full processes
7873-
// only the speech segments in this case so we need to map these timestamps
7874-
// back to the original audio.
7875-
float t1 = state->result_all[i_segment].t1 / 100.0f;
7876-
7877-
// Find which VAD segment this timestamp belongs.
7878-
// TODO(danbev) This could be optimized by using a binary search if the number
7879-
// of segments exceed a certain limit. Also we might be able to assume that
7880-
// the access pattern is sequential and optimized for that too.
7881-
for (size_t i = 0; i < state->vad_segments.size(); i++) {
7882-
const auto& segment = state->vad_segments[i];
7883-
7884-
// Check if the timestamp falls within this segment.
7885-
if (t1 >= segment.vad_start && t1 <= segment.vad_end) {
7886-
// Calculate the proportion through the filtered segment.
7887-
float proportion = 0.0f;
7888-
if (segment.vad_end > segment.vad_start) {
7889-
proportion = (t1 - segment.vad_start) / (segment.vad_end - segment.vad_start);
7890-
}
7891-
float orig_t1 = segment.orig_start + proportion * (segment.orig_end - segment.orig_start);
7892-
return (int64_t)(orig_t1 * 100);
7893-
}
7894-
}
7942+
// Get the processed timestamp
7943+
double t1 = state->result_all[i_segment].t1 / 100.0;
78957944

7896-
// Check if the timestamp falls between two segments.
7897-
for (size_t i = 0; i < state->vad_segments.size() - 1; i++) {
7898-
const auto & curr = state->vad_segments[i];
7899-
const auto & next = state->vad_segments[i + 1];
7945+
// Map to original time using the mapping table
7946+
double orig_t1 = map_processed_to_original_time(t1, state->vad_mapping_table);
79007947

7901-
if (t1 > curr.vad_end && t1 < next.vad_start) {
7902-
// Calculate how far we are through the gap as a proportion
7903-
float gap_proportion = 0.0f;
7904-
if (next.vad_start > curr.vad_end) {
7905-
gap_proportion = (t1 - curr.vad_end) / (next.vad_start - curr.vad_end);
7906-
}
7907-
// Map to the corresponding position in the original gap
7908-
float orig_t1 = curr.orig_end + gap_proportion * (next.orig_start - curr.orig_end);
7909-
return (int64_t)(orig_t1 * 100);
7910-
}
7911-
}
7948+
// Get the corresponding t0 for this segment
7949+
double orig_t0 = whisper_full_get_segment_t0_from_state(state, i_segment) / 100.0;
79127950

7913-
// Handle the case where the timestamp is after the last segment
7914-
if (t1 > state->vad_segments.back().vad_end) {
7915-
// For the last segment, use the end of the last VAD segment
7916-
const auto& last = state->vad_segments.back();
7917-
// Calculate how far beyond the last segment
7918-
float extra_time = t1 - last.vad_end;
7919-
// Add this extra time to the original end time
7920-
float orig_t1 = last.orig_end + extra_time;
7921-
return (int64_t)(orig_t1 * 100);
7951+
// Ensure minimum duration to prevent zero-length segments
7952+
const double min_duration = 0.01; // 10ms minimum
7953+
if (orig_t1 - orig_t0 < min_duration) {
7954+
orig_t1 = orig_t0 + min_duration;
79227955
}
79237956

7924-
WHISPER_LOG_WARN("%s: Could not map t1 = %f to a VAD segment\n", __func__, t1);
7925-
return t1;
7957+
return (int64_t)(orig_t1 * 100 + 0.5); // Round to nearest
7958+
}
7959+
7960+
7961+
int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
7962+
return whisper_full_get_segment_t0_from_state(ctx->state, i_segment);
79267963
}
79277964

79287965
int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {

0 commit comments

Comments
 (0)