|
1 | | -use swift_rs::swift; |
| 1 | +use std::sync::OnceLock; |
2 | 2 |
|
3 | | -swift!(fn initialize_am2_sdk()); |
| 3 | +use swift_rs::{swift, SRArray, SRObject, SRString}; |
4 | 4 |
|
5 | | -swift!(fn check_am2_ready() -> bool); |
| 5 | +swift!(fn initialize_am2_sdk(api_key: &SRString)); |
| 6 | +swift!(fn am2_vad_init() -> bool); |
| 7 | +swift!(fn am2_vad_detect(samples_ptr: *const f32, samples_len: i64) -> SRObject<VadResultArray>); |
| 8 | +swift!(fn am2_vad_index_to_seconds(index: i64) -> f32); |
| 9 | + |
| 10 | +static SDK_INITIALIZED: OnceLock<()> = OnceLock::new(); |
| 11 | + |
| 12 | +#[repr(C)] |
| 13 | +pub struct VadResultArray { |
| 14 | + pub data: SRArray<bool>, |
| 15 | +} |
6 | 16 |
|
7 | 17 | pub fn init() { |
8 | | - unsafe { |
9 | | - initialize_am2_sdk(); |
10 | | - } |
| 18 | + SDK_INITIALIZED.get_or_init(|| { |
| 19 | + let api_key = std::env::var("AM_API_KEY").unwrap_or_default(); |
| 20 | + let api_key = SRString::from(api_key.as_str()); |
| 21 | + unsafe { |
| 22 | + initialize_am2_sdk(&api_key); |
| 23 | + } |
| 24 | + }); |
11 | 25 | } |
12 | 26 |
|
13 | 27 | pub fn is_ready() -> bool { |
14 | | - unsafe { check_am2_ready() } |
| 28 | + SDK_INITIALIZED.get().is_some() |
| 29 | +} |
| 30 | + |
| 31 | +pub mod vad { |
| 32 | + use std::sync::OnceLock; |
| 33 | + |
| 34 | + use super::*; |
| 35 | + |
| 36 | + static VAD_INITIALIZED: OnceLock<bool> = OnceLock::new(); |
| 37 | + |
| 38 | + pub fn init() -> bool { |
| 39 | + *VAD_INITIALIZED.get_or_init(|| unsafe { am2_vad_init() }) |
| 40 | + } |
| 41 | + |
| 42 | + pub fn is_ready() -> bool { |
| 43 | + VAD_INITIALIZED.get().copied().unwrap_or(false) |
| 44 | + } |
| 45 | + |
| 46 | + pub fn detect(samples: &[f32]) -> Vec<bool> { |
| 47 | + let result = unsafe { am2_vad_detect(samples.as_ptr(), samples.len() as i64) }; |
| 48 | + result.data.as_slice().to_vec() |
| 49 | + } |
| 50 | + |
| 51 | + pub fn index_to_seconds(index: usize) -> f32 { |
| 52 | + unsafe { am2_vad_index_to_seconds(index as i64) } |
| 53 | + } |
| 54 | + |
| 55 | + #[derive(Debug, Clone)] |
| 56 | + pub struct VoiceSegment { |
| 57 | + pub start_seconds: f32, |
| 58 | + pub end_seconds: f32, |
| 59 | + } |
| 60 | + |
| 61 | + pub fn detect_segments(samples: &[f32]) -> Vec<VoiceSegment> { |
| 62 | + let voice_activity = detect(samples); |
| 63 | + let mut segments = Vec::new(); |
| 64 | + let mut in_voice = false; |
| 65 | + let mut segment_start = 0.0; |
| 66 | + |
| 67 | + for (i, &is_voice) in voice_activity.iter().enumerate() { |
| 68 | + if is_voice && !in_voice { |
| 69 | + segment_start = index_to_seconds(i); |
| 70 | + in_voice = true; |
| 71 | + } else if !is_voice && in_voice { |
| 72 | + segments.push(VoiceSegment { |
| 73 | + start_seconds: segment_start, |
| 74 | + end_seconds: index_to_seconds(i), |
| 75 | + }); |
| 76 | + in_voice = false; |
| 77 | + } |
| 78 | + } |
| 79 | + |
| 80 | + if in_voice { |
| 81 | + segments.push(VoiceSegment { |
| 82 | + start_seconds: segment_start, |
| 83 | + end_seconds: index_to_seconds(voice_activity.len()), |
| 84 | + }); |
| 85 | + } |
| 86 | + |
| 87 | + segments |
| 88 | + } |
15 | 89 | } |
16 | 90 |
|
17 | 91 | #[cfg(test)] |
18 | 92 | mod tests { |
19 | 93 | use super::*; |
20 | 94 |
|
21 | 95 | #[test] |
22 | | - fn test_am2_swift_compilation() { |
| 96 | + fn test_am2_sdk_init() { |
23 | 97 | init(); |
24 | 98 | assert!(is_ready()); |
25 | 99 | } |
| 100 | + |
| 101 | + #[test] |
| 102 | + fn test_am2_vad_init() { |
| 103 | + init(); |
| 104 | + assert!(vad::init()); |
| 105 | + assert!(vad::is_ready()); |
| 106 | + } |
26 | 107 | } |
0 commit comments