Skip to content

Commit 9bf2a02

Browse files
authored
refactor(vad): unify VAD logic with StreamingVad (#2156)
- Introduce StreamingVad in vad-ext to centralize VAD processing
- Refactor VadAgc to use StreamingVad with optional masking
- Refactor ContinuousVadMaskStream to delegate to StreamingVad
- Remove double VAD in mic path: use VadAgc::with_masking(true) instead
- Simplify dependencies: agc now only depends on vad-ext
1 parent c9cffaa commit 9bf2a02

File tree

7 files changed

+219
-193
lines changed

7 files changed

+219
-193
lines changed

Cargo.lock

Lines changed: 1 addition & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/agc/Cargo.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ rodio = { workspace = true }
1010

1111
[dependencies]
1212
hypr-audio-utils = { workspace = true }
13-
hypr-vad3 = { workspace = true }
14-
hypr-vvad = { workspace = true }
13+
hypr-vad-ext = { workspace = true }
1514

1615
dagc = "0.1.1"

crates/agc/src/lib.rs

Lines changed: 29 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,84 +1,52 @@
11
use dagc::MonoAgc;
2-
use hypr_audio_utils::f32_to_i16_samples;
3-
use hypr_vvad::VoiceActivityDetector;
2+
use hypr_vad_ext::{StreamingVad, VadConfig};
43

54
pub struct VadAgc {
65
agc: MonoAgc,
7-
vad: VoiceActivityDetector,
8-
frame_size: usize,
9-
vad_tail: Vec<f32>,
10-
last_is_speech: bool,
6+
vad: Option<StreamingVad>,
7+
vad_cfg: VadConfig,
8+
mask_non_speech: bool,
119
}
1210

1311
impl VadAgc {
1412
pub fn new(desired_output_rms: f32, distortion_factor: f32) -> Self {
1513
Self {
1614
agc: MonoAgc::new(desired_output_rms, distortion_factor).expect("failed_to_create_agc"),
17-
vad: VoiceActivityDetector::new(),
18-
frame_size: 0,
19-
vad_tail: Vec::new(),
20-
last_is_speech: true,
15+
vad: None,
16+
vad_cfg: VadConfig::default(),
17+
mask_non_speech: false,
2118
}
2219
}
2320

21+
pub fn with_masking(mut self, mask_non_speech: bool) -> Self {
22+
self.mask_non_speech = mask_non_speech;
23+
self
24+
}
25+
26+
pub fn with_vad_config(mut self, cfg: VadConfig) -> Self {
27+
self.vad_cfg = cfg;
28+
self
29+
}
30+
2431
pub fn process(&mut self, samples: &mut [f32]) {
2532
if samples.is_empty() {
2633
return;
2734
}
2835

29-
if self.frame_size == 0 {
30-
self.frame_size = hypr_vad3::choose_optimal_frame_size(samples.len());
31-
}
32-
let frame_size = self.frame_size;
33-
34-
let mut pos = 0;
36+
let vad = self
37+
.vad
38+
.get_or_insert_with(|| StreamingVad::with_config(samples.len(), self.vad_cfg.clone()));
3539

36-
if !self.vad_tail.is_empty() {
37-
let needed = frame_size - self.vad_tail.len();
38-
let to_take = needed.min(samples.len());
40+
let agc = &mut self.agc;
41+
let mask_non_speech = self.mask_non_speech;
3942

40-
let mut frame_f32 = std::mem::take(&mut self.vad_tail);
41-
frame_f32.reserve(frame_size - frame_f32.len());
42-
frame_f32.extend_from_slice(&samples[..to_take]);
43-
44-
if frame_f32.len() == frame_size {
45-
let i16_samples = f32_to_i16_samples(&frame_f32);
46-
let is_speech = self.vad.predict_16khz(&i16_samples).unwrap_or(true);
47-
self.last_is_speech = is_speech;
48-
49-
self.agc.freeze_gain(!is_speech);
50-
self.agc.process(&mut samples[..to_take]);
51-
52-
pos = to_take;
53-
} else {
54-
self.vad_tail = frame_f32;
55-
56-
self.agc.freeze_gain(!self.last_is_speech);
57-
self.agc.process(samples);
58-
return;
43+
vad.process_in_place(samples, |frame, is_speech| {
44+
agc.freeze_gain(!is_speech);
45+
if !is_speech && mask_non_speech {
46+
frame.fill(0.0);
5947
}
60-
}
61-
62-
while samples.len() - pos >= frame_size {
63-
let frame = &mut samples[pos..pos + frame_size];
64-
65-
let i16_samples = f32_to_i16_samples(frame);
66-
let is_speech = self.vad.predict_16khz(&i16_samples).unwrap_or(true);
67-
self.last_is_speech = is_speech;
68-
69-
self.agc.freeze_gain(!is_speech);
70-
self.agc.process(frame);
71-
72-
pos += frame_size;
73-
}
74-
75-
if pos < samples.len() {
76-
self.vad_tail.clear();
77-
self.vad_tail.extend_from_slice(&samples[pos..]);
78-
79-
self.agc.freeze_gain(!self.last_is_speech);
80-
self.agc.process(&mut samples[pos..]);
81-
}
48+
agc.process(frame);
49+
});
8250
}
8351

8452
pub fn gain(&self) -> f32 {
@@ -88,13 +56,7 @@ impl VadAgc {
8856

8957
impl Default for VadAgc {
9058
fn default() -> Self {
91-
Self {
92-
agc: MonoAgc::new(0.03, 0.0001).expect("failed_to_create_agc"),
93-
vad: VoiceActivityDetector::new(),
94-
frame_size: 0,
95-
vad_tail: Vec::new(),
96-
last_is_speech: true,
97-
}
59+
Self::new(0.03, 0.0001)
9860
}
9961
}
10062

crates/vad-ext/src/continuous2.rs

Lines changed: 17 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -4,34 +4,31 @@ use std::{
44
};
55

66
use futures_util::Stream;
7-
use hypr_audio_utils::f32_to_i16_samples;
8-
use hypr_vvad::VoiceActivityDetector;
7+
8+
use crate::{StreamingVad, VadConfig};
99

1010
pub struct ContinuousVadMaskStream<S> {
1111
inner: S,
12-
vad: VoiceActivityDetector,
13-
hangover_frames: usize,
14-
trailing_non_speech: usize,
15-
in_speech: bool,
16-
scratch_frame: Vec<f32>,
17-
amplitude_floor: f32,
12+
vad: Option<StreamingVad>,
13+
cfg: VadConfig,
1814
}
1915

2016
impl<S> ContinuousVadMaskStream<S> {
2117
pub fn new(inner: S) -> Self {
2218
Self {
2319
inner,
24-
vad: VoiceActivityDetector::new(),
25-
hangover_frames: 3,
26-
trailing_non_speech: 0,
27-
in_speech: true,
28-
scratch_frame: Vec::new(),
29-
amplitude_floor: 0.001,
20+
vad: None,
21+
cfg: VadConfig::default(),
3022
}
3123
}
3224

3325
pub fn with_hangover_frames(mut self, frames: usize) -> Self {
34-
self.hangover_frames = frames;
26+
self.cfg.hangover_frames = frames;
27+
self
28+
}
29+
30+
pub fn with_amplitude_floor(mut self, floor: f32) -> Self {
31+
self.cfg.amplitude_floor = floor;
3532
self
3633
}
3734

@@ -40,68 +37,15 @@ impl<S> ContinuousVadMaskStream<S> {
4037
return;
4138
}
4239

43-
let frame_size = hypr_vad3::choose_optimal_frame_size(chunk.len());
44-
debug_assert!(frame_size > 0, "VAD frame size must be > 0");
45-
46-
for frame in chunk.chunks_mut(frame_size) {
47-
self.process_frame(frame, frame_size);
48-
}
49-
}
50-
51-
fn smooth_vad_decision(&mut self, raw_is_speech: bool) -> bool {
52-
if raw_is_speech {
53-
self.in_speech = true;
54-
self.trailing_non_speech = 0;
55-
true
56-
} else if self.in_speech && self.trailing_non_speech < self.hangover_frames {
57-
self.trailing_non_speech += 1;
58-
true
59-
} else {
60-
self.in_speech = false;
61-
self.trailing_non_speech = 0;
62-
false
63-
}
64-
}
65-
66-
fn process_frame(&mut self, frame: &mut [f32], frame_size: usize) {
67-
if frame.is_empty() {
68-
return;
69-
}
40+
let vad = self
41+
.vad
42+
.get_or_insert_with(|| StreamingVad::with_config(chunk.len(), self.cfg.clone()));
7043

71-
let rms = Self::calculate_rms(frame);
72-
if rms < self.amplitude_floor {
73-
let is_speech = self.smooth_vad_decision(false);
44+
vad.process_in_place(chunk, |frame, is_speech| {
7445
if !is_speech {
7546
frame.fill(0.0);
7647
}
77-
return;
78-
}
79-
80-
let raw_is_speech = if frame.len() == frame_size {
81-
let i16_samples = f32_to_i16_samples(frame);
82-
self.vad.predict_16khz(&i16_samples).unwrap_or(true)
83-
} else {
84-
self.scratch_frame.clear();
85-
self.scratch_frame.extend_from_slice(frame);
86-
self.scratch_frame.resize(frame_size, 0.0);
87-
88-
let i16_samples = f32_to_i16_samples(&self.scratch_frame);
89-
self.vad.predict_16khz(&i16_samples).unwrap_or(true)
90-
};
91-
92-
let is_speech = self.smooth_vad_decision(raw_is_speech);
93-
94-
if !is_speech {
95-
frame.fill(0.0);
96-
}
97-
}
98-
99-
fn calculate_rms(samples: &[f32]) -> f32 {
100-
if samples.is_empty() {
101-
return 0.0;
102-
}
103-
let sum_sq: f32 = samples.iter().map(|&s| s * s).sum();
104-
(sum_sq / samples.len() as f32).sqrt()
48+
});
10549
}
10650
}
10751

@@ -194,8 +138,6 @@ mod tests {
194138
}
195139
}
196140

197-
// We should not *introduce* any non-zero samples, and the vast majority
198-
// of silence should stay zero.
199141
let non_zero_count = masked_samples.iter().filter(|&&s| s != 0.0).count();
200142
assert!(
201143
non_zero_count < 100,
@@ -204,48 +146,6 @@ mod tests {
204146
);
205147
}
206148

207-
#[test]
208-
fn test_hangover_logic() {
209-
// Use an empty inner stream; we only care about the internal state machine.
210-
let mut vad_stream = ContinuousVadMaskStream::new(stream::empty::<Result<Vec<f32>, ()>>());
211-
vad_stream.hangover_frames = 3;
212-
213-
// Initial state is conservative: in_speech = true
214-
assert!(vad_stream.in_speech);
215-
assert_eq!(vad_stream.trailing_non_speech, 0);
216-
217-
// Simulate raw VAD decisions: T, F, F, F, F
218-
// First: raw speech
219-
assert!(vad_stream.smooth_vad_decision(true));
220-
assert!(vad_stream.in_speech);
221-
assert_eq!(vad_stream.trailing_non_speech, 0);
222-
223-
// First false: still treated as speech (hangover 1/3)
224-
assert!(vad_stream.smooth_vad_decision(false));
225-
assert!(vad_stream.in_speech);
226-
assert_eq!(vad_stream.trailing_non_speech, 1);
227-
228-
// Second false: still speech (hangover 2/3)
229-
assert!(vad_stream.smooth_vad_decision(false));
230-
assert!(vad_stream.in_speech);
231-
assert_eq!(vad_stream.trailing_non_speech, 2);
232-
233-
// Third false: still speech (hangover 3/3)
234-
assert!(vad_stream.smooth_vad_decision(false));
235-
assert!(vad_stream.in_speech);
236-
assert_eq!(vad_stream.trailing_non_speech, 3);
237-
238-
// Fourth false: now we finally flip to non-speech
239-
assert!(!vad_stream.smooth_vad_decision(false));
240-
assert!(!vad_stream.in_speech);
241-
assert_eq!(vad_stream.trailing_non_speech, 0);
242-
243-
// More false: stays non-speech
244-
assert!(!vad_stream.smooth_vad_decision(false));
245-
assert!(!vad_stream.in_speech);
246-
assert_eq!(vad_stream.trailing_non_speech, 0);
247-
}
248-
249149
#[tokio::test]
250150
async fn test_different_chunk_sizes() {
251151
let input_audio = rodio::Decoder::new(std::io::BufReader::new(
@@ -316,7 +216,6 @@ mod tests {
316216

317217
#[test]
318218
fn test_frame_size_selection() {
319-
// Sanity-check assumptions about the VAD helper we're using.
320219
assert_eq!(hypr_vad3::choose_optimal_frame_size(160), 160);
321220
assert_eq!(hypr_vad3::choose_optimal_frame_size(320), 320);
322221
assert_eq!(hypr_vad3::choose_optimal_frame_size(480), 480);

crates/vad-ext/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
mod continuous;
22
mod continuous2;
33
mod error;
4+
mod streaming;
45

56
pub use continuous::*;
67
pub use continuous2::*;
78
pub use error::*;
9+
pub use streaming::*;
810

911
#[cfg(test)]
1012
pub mod tests {

0 commit comments

Comments
 (0)