Skip to content

Commit 83d8ca8

Browse files
committed
feat: use vad
1 parent 1635590 commit 83d8ca8

File tree

9 files changed

+95
-122
lines changed

9 files changed

+95
-122
lines changed

Cargo.lock

Lines changed: 36 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,5 +193,7 @@ objc2-core-foundation = "0.3"
193193
objc2-foundation = "0.3"
194194
objc2-user-notifications = "0.3"
195195

196+
voice_activity_detector = "0.2"
197+
196198
[patch.crates-io]
197199
cpal = { git = "https://github.com/RustAudio/cpal", rev = "51c3b43" }

crates/chunker/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ hound = { workspace = true }
88
hypr-data = { workspace = true }
99

1010
[dependencies]
11-
hypr-vad = { workspace = true }
1211
kalosm-sound = { workspace = true, default-features = false }
1312
rodio = { workspace = true, features = ["wav"] }
1413

@@ -17,3 +16,4 @@ serde = { workspace = true }
1716
thiserror = { workspace = true }
1817
tokio = { workspace = true, features = ["rt-multi-thread", "macros"] }
1918
tracing = { workspace = true }
19+
voice_activity_detector = { workspace = true }

crates/chunker/src/error.rs

Lines changed: 0 additions & 16 deletions
This file was deleted.

crates/chunker/src/lib.rs

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,17 @@
1-
mod error;
2-
mod predictor;
31
mod stream;
42

5-
pub use error::*;
6-
pub use predictor::*;
73
pub use stream::*;
84

95
use kalosm_sound::AsyncSource;
106
use std::time::Duration;
7+
use voice_activity_detector::VoiceActivityDetector;
118

129
pub trait ChunkerExt: AsyncSource + Sized {
13-
fn chunks<P: Predictor + Unpin>(
14-
self,
15-
predictor: P,
16-
chunk_duration: Duration,
17-
) -> ChunkStream<Self, P>
10+
fn chunks(self, vad: VoiceActivityDetector, chunk_duration: Duration) -> ChunkStream<Self>
1811
where
1912
Self: Unpin,
2013
{
21-
ChunkStream::new(self, predictor, chunk_duration)
14+
ChunkStream::new(self, vad, chunk_duration)
2215
}
2316
}
2417

@@ -28,6 +21,7 @@ impl<T: AsyncSource> ChunkerExt for T {}
2821
mod tests {
2922
use super::*;
3023
use futures_util::StreamExt;
24+
use voice_activity_detector::VoiceActivityDetector;
3125

3226
#[tokio::test]
3327
async fn test_chunker() {
@@ -43,7 +37,12 @@ mod tests {
4337
sample_format: hound::SampleFormat::Float,
4438
};
4539

46-
let mut stream = audio_source.chunks(RMS::new(), Duration::from_secs(15));
40+
let vad = VoiceActivityDetector::builder()
41+
.sample_rate(16000)
42+
.chunk_size(512usize)
43+
.build()
44+
.unwrap();
45+
let mut stream = audio_source.chunks(vad, Duration::from_secs(15));
4746
let mut i = 0;
4847

4948
std::fs::remove_dir_all("tmp/english_1").unwrap();

crates/chunker/src/predictor.rs

Lines changed: 0 additions & 45 deletions
This file was deleted.

crates/chunker/src/stream.rs

Lines changed: 34 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,20 @@ use std::{
77

88
use kalosm_sound::AsyncSource;
99
use rodio::buffer::SamplesBuffer;
10+
use voice_activity_detector::{IteratorExt, VoiceActivityDetector};
1011

11-
use crate::Predictor;
12-
13-
pub struct ChunkStream<S: AsyncSource + Unpin, P: Predictor + Unpin> {
12+
pub struct ChunkStream<S: AsyncSource + Unpin> {
1413
source: S,
15-
predictor: P,
14+
vad: VoiceActivityDetector,
1615
buffer: Vec<f32>,
1716
max_duration: Duration,
1817
}
1918

20-
impl<S: AsyncSource + Unpin, P: Predictor + Unpin> ChunkStream<S, P> {
21-
pub fn new(source: S, predictor: P, max_duration: Duration) -> Self {
19+
impl<S: AsyncSource + Unpin> ChunkStream<S> {
20+
pub fn new(source: S, vad: VoiceActivityDetector, max_duration: Duration) -> Self {
2221
Self {
2322
source,
24-
predictor,
23+
vad,
2524
buffer: Vec::new(),
2625
max_duration,
2726
}
@@ -34,26 +33,9 @@ impl<S: AsyncSource + Unpin, P: Predictor + Unpin> ChunkStream<S, P> {
3433
fn samples_for_duration(&self, duration: Duration) -> usize {
3534
(self.source.sample_rate() as f64 * duration.as_secs_f64()) as usize
3635
}
37-
38-
fn trim_silence(predictor: &P, data: &mut Vec<f32>) {
39-
const WINDOW_SIZE: usize = 100;
40-
41-
let mut trim_index = 0;
42-
for start_idx in (0..data.len()).step_by(WINDOW_SIZE) {
43-
let end_idx = (start_idx + WINDOW_SIZE).min(data.len());
44-
let window = &data[start_idx..end_idx];
45-
46-
if let Ok(false) = predictor.predict(window) {
47-
trim_index = start_idx;
48-
break;
49-
}
50-
}
51-
52-
data.drain(0..trim_index);
53-
}
5436
}
5537

56-
impl<S: AsyncSource + Unpin, P: Predictor + Unpin> Stream for ChunkStream<S, P> {
38+
impl<S: AsyncSource + Unpin> Stream for ChunkStream<S> {
5739
type Item = SamplesBuffer<f32>;
5840

5941
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
@@ -62,8 +44,6 @@ impl<S: AsyncSource + Unpin, P: Predictor + Unpin> Stream for ChunkStream<S, P>
6244
let sample_rate = this.source.sample_rate();
6345

6446
let min_buffer_samples = this.samples_for_duration(Duration::from_secs(6));
65-
let silence_window_samples = this.samples_for_duration(Duration::from_millis(500));
66-
6747
let stream = this.source.as_stream();
6848
let mut stream = std::pin::pin!(stream);
6949

@@ -73,32 +53,41 @@ impl<S: AsyncSource + Unpin, P: Predictor + Unpin> Stream for ChunkStream<S, P>
7353
this.buffer.push(sample);
7454

7555
if this.buffer.len() >= min_buffer_samples {
76-
let buffer_len = this.buffer.len();
77-
let silence_start = buffer_len.saturating_sub(silence_window_samples);
78-
let last_samples = &this.buffer[silence_start..buffer_len];
79-
80-
if let Ok(false) = this.predictor.predict(last_samples) {
81-
let mut data = std::mem::take(&mut this.buffer);
82-
Self::trim_silence(&this.predictor, &mut data);
83-
84-
return Poll::Ready(Some(SamplesBuffer::new(1, sample_rate, data)));
85-
}
56+
let data = std::mem::take(&mut this.buffer);
57+
let speech = filter_speech_chunks(&mut this.vad, data);
58+
return Poll::Ready(Some(SamplesBuffer::new(1, sample_rate, speech)));
8659
}
8760
}
8861
Poll::Ready(None) if !this.buffer.is_empty() => {
89-
let mut data = std::mem::take(&mut this.buffer);
90-
Self::trim_silence(&this.predictor, &mut data);
91-
92-
return Poll::Ready(Some(SamplesBuffer::new(1, sample_rate, data)));
62+
let data = std::mem::take(&mut this.buffer);
63+
let speech = filter_speech_chunks(&mut this.vad, data);
64+
return Poll::Ready(Some(SamplesBuffer::new(1, sample_rate, speech)));
9365
}
9466
Poll::Ready(None) => return Poll::Ready(None),
9567
Poll::Pending => return Poll::Pending,
9668
}
9769
}
9870

99-
let mut chunk: Vec<_> = this.buffer.drain(0..max_samples).collect();
100-
Self::trim_silence(&this.predictor, &mut chunk);
101-
102-
Poll::Ready(Some(SamplesBuffer::new(1, sample_rate, chunk)))
71+
let data = this.buffer.drain(0..max_samples);
72+
let speech = filter_speech_chunks(&mut this.vad, data);
73+
Poll::Ready(Some(SamplesBuffer::new(1, sample_rate, speech)))
10374
}
10475
}
76+
77+
// helper function to filter speech chunks
78+
fn filter_speech_chunks<D: IntoIterator<Item = f32>>(
79+
vad: &mut VoiceActivityDetector,
80+
data: D,
81+
) -> Vec<f32> {
82+
data.into_iter()
83+
.label(vad, 0.75, 3)
84+
.filter_map(|label| {
85+
if label.is_speech() {
86+
Some(label.into_iter())
87+
} else {
88+
None
89+
}
90+
})
91+
.flatten()
92+
.collect()
93+
}

plugins/local-stt/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ futures-util = { workspace = true }
5050
tokio = { workspace = true, features = ["rt", "macros"] }
5151
tokio-util = { workspace = true }
5252

53+
voice_activity_detector = { workspace = true }
54+
5355
[target.'cfg(not(target_os = "macos"))'.dependencies]
5456
kalosm-sound = { workspace = true, default-features = false }
5557

plugins/local-stt/src/server.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ use hypr_chunker::ChunkerExt;
2121
use hypr_listener_interface::{ListenOutputChunk, ListenParams, Word};
2222
use hypr_ws_utils::WebSocketAudioSource;
2323

24+
use voice_activity_detector::VoiceActivityDetector;
25+
2426
use crate::manager::{ConnectionGuard, ConnectionManager};
2527

2628
#[derive(Default)]
@@ -144,9 +146,14 @@ async fn websocket_with_model(
144146
async fn websocket(socket: WebSocket, model: hypr_whisper::local::Whisper, guard: ConnectionGuard) {
145147
let (mut ws_sender, ws_receiver) = socket.split();
146148
let mut stream = {
147-
let audio_source = WebSocketAudioSource::new(ws_receiver, 16 * 1000);
148-
let chunked =
149-
audio_source.chunks(hypr_chunker::RMS::new(), std::time::Duration::from_secs(15));
149+
let sample_rate = 16_000;
150+
let audio_source = WebSocketAudioSource::new(ws_receiver, sample_rate);
151+
let vad = VoiceActivityDetector::builder()
152+
.sample_rate(sample_rate)
153+
.chunk_size(512usize)
154+
.build()
155+
.expect("vad config is valid");
156+
let chunked = audio_source.chunks(vad, std::time::Duration::from_secs(15));
150157
hypr_whisper::local::TranscribeChunkedAudioStreamExt::transcribe(chunked, model)
151158
};
152159

0 commit comments

Comments
 (0)