Skip to content

Commit 0703a98

Browse files
committed
fix am2 and add vvad
1 parent 78ae5c6 commit 0703a98

File tree

9 files changed

+257
-31
lines changed

9 files changed

+257
-31
lines changed
Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# https://docs.argmaxinc.com/guides/upgrading-to-pro-sdk
1+
# https://app.argmaxinc.com/docs/guides/upgrading-to-pro-sdk
22
# Sets up Argmax Pro SDK registry access for Swift Package Manager
33

44
name: "Setup Argmax SDK"
@@ -10,12 +10,8 @@ inputs:
1010
runs:
1111
using: "composite"
1212
steps:
13-
- name: Setup Argmax Registry
14-
shell: bash
15-
run: |
16-
swift package-registry set --global --scope argmaxinc https://api.argmaxinc.com/v1/sdk
13+
- shell: bash
14+
run: swift package-registry set --global --scope argmaxinc https://api.argmaxinc.com/v1/sdk
1715

18-
- name: Authenticate with Argmax
19-
shell: bash
20-
run: |
21-
swift package-registry login https://api.argmaxinc.com/v1/sdk/login --token "${{ inputs.api-token }}"
16+
- shell: bash
17+
run: swift package-registry login https://api.argmaxinc.com/v1/sdk/login --token "${{ inputs.api-token }}"

.github/workflows/desktop_cd.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ jobs:
9999
uses: ./.github/actions/argmax_sdk_setup
100100
with:
101101
api-token: ${{ secrets.AM_SECRET_TOKEN }}
102+
- if: ${{ matrix.include_am }}
103+
run: cargo test -p am2
102104
- run: pnpm -F ui build
103105
- if: ${{ matrix.include_am }}
104106
run: |

Cargo.lock

Lines changed: 16 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ hypr-turso = { path = "crates/turso", package = "turso" }
7979
hypr-vad-ext = { path = "crates/vad-ext", package = "vad-ext" }
8080
hypr-vad2 = { path = "crates/vad2", package = "vad2" }
8181
hypr-vad3 = { path = "crates/vad3", package = "vad3" }
82+
hypr-vvad = { path = "crates/vvad", package = "vvad" }
8283
hypr-whisper = { path = "crates/whisper", package = "whisper" }
8384
hypr-whisper-local = { path = "crates/whisper-local", package = "whisper-local" }
8485
hypr-whisper-local-model = { path = "crates/whisper-local-model", package = "whisper-local-model" }
@@ -256,7 +257,7 @@ objc2-user-notifications = "0.3"
256257

257258
tokenizers = "0.21.4"
258259

259-
swift-rs = { git = "https://github.com/yujonglee/swift-rs", branch = "fix-framework" }
260+
swift-rs = { git = "https://github.com/yujonglee/swift-rs", rev = "8c39ff4" }
260261
sysinfo = "0.37.0"
261262

262263
[patch.crates-io]

crates/am2/build.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ fn main() {
33
{
44
swift_rs::SwiftLinker::new("13.0")
55
.with_package("swift-lib", "./swift-lib/")
6+
.with_framework("ArgmaxSDK")
67
.link();
78
}
89
}

crates/am2/src/lib.rs

Lines changed: 89 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,107 @@
1-
use swift_rs::swift;
1+
use std::sync::OnceLock;
22

3-
swift!(fn initialize_am2_sdk());
3+
use swift_rs::{swift, SRArray, SRObject, SRString};
44

5-
swift!(fn check_am2_ready() -> bool);
5+
swift!(fn initialize_am2_sdk(api_key: &SRString));
6+
swift!(fn am2_vad_init() -> bool);
7+
swift!(fn am2_vad_detect(samples_ptr: *const f32, samples_len: i64) -> SRObject<VadResultArray>);
8+
swift!(fn am2_vad_index_to_seconds(index: i64) -> f32);
9+
10+
static SDK_INITIALIZED: OnceLock<()> = OnceLock::new();
11+
12+
#[repr(C)]
13+
pub struct VadResultArray {
14+
pub data: SRArray<bool>,
15+
}
616

717
pub fn init() {
8-
unsafe {
9-
initialize_am2_sdk();
10-
}
18+
SDK_INITIALIZED.get_or_init(|| {
19+
let api_key = std::env::var("AM_API_KEY").unwrap_or_default();
20+
let api_key = SRString::from(api_key.as_str());
21+
unsafe {
22+
initialize_am2_sdk(&api_key);
23+
}
24+
});
1125
}
1226

1327
pub fn is_ready() -> bool {
14-
unsafe { check_am2_ready() }
28+
SDK_INITIALIZED.get().is_some()
29+
}
30+
31+
pub mod vad {
32+
use std::sync::OnceLock;
33+
34+
use super::*;
35+
36+
static VAD_INITIALIZED: OnceLock<bool> = OnceLock::new();
37+
38+
pub fn init() -> bool {
39+
*VAD_INITIALIZED.get_or_init(|| unsafe { am2_vad_init() })
40+
}
41+
42+
pub fn is_ready() -> bool {
43+
VAD_INITIALIZED.get().copied().unwrap_or(false)
44+
}
45+
46+
pub fn detect(samples: &[f32]) -> Vec<bool> {
47+
let result = unsafe { am2_vad_detect(samples.as_ptr(), samples.len() as i64) };
48+
result.data.as_slice().to_vec()
49+
}
50+
51+
pub fn index_to_seconds(index: usize) -> f32 {
52+
unsafe { am2_vad_index_to_seconds(index as i64) }
53+
}
54+
55+
#[derive(Debug, Clone)]
56+
pub struct VoiceSegment {
57+
pub start_seconds: f32,
58+
pub end_seconds: f32,
59+
}
60+
61+
pub fn detect_segments(samples: &[f32]) -> Vec<VoiceSegment> {
62+
let voice_activity = detect(samples);
63+
let mut segments = Vec::new();
64+
let mut in_voice = false;
65+
let mut segment_start = 0.0;
66+
67+
for (i, &is_voice) in voice_activity.iter().enumerate() {
68+
if is_voice && !in_voice {
69+
segment_start = index_to_seconds(i);
70+
in_voice = true;
71+
} else if !is_voice && in_voice {
72+
segments.push(VoiceSegment {
73+
start_seconds: segment_start,
74+
end_seconds: index_to_seconds(i),
75+
});
76+
in_voice = false;
77+
}
78+
}
79+
80+
if in_voice {
81+
segments.push(VoiceSegment {
82+
start_seconds: segment_start,
83+
end_seconds: index_to_seconds(voice_activity.len()),
84+
});
85+
}
86+
87+
segments
88+
}
1589
}
1690

1791
#[cfg(test)]
1892
mod tests {
1993
use super::*;
2094

2195
#[test]
22-
fn test_am2_swift_compilation() {
96+
fn test_am2_sdk_init() {
2397
init();
2498
assert!(is_ready());
2599
}
100+
101+
#[test]
102+
fn test_am2_vad_init() {
103+
init();
104+
assert!(vad::init());
105+
assert!(vad::is_ready());
106+
}
26107
}

crates/am2/swift-lib/src/lib.swift

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,69 @@ import Argmax
22
import Foundation
33
import SwiftRs
44

5-
private var isAM2Ready = false
5+
private var vadInstance: VoiceActivityDetector?
66

77
@_cdecl("initialize_am2_sdk")
8-
public func initialize_am2_sdk() {
9-
isAM2Ready = true
10-
print("AM2 SDK initialized successfully")
8+
public func initialize_am2_sdk(apiKey: SRString) {
9+
let semaphore = DispatchSemaphore(value: 0)
10+
11+
Task {
12+
let key = apiKey.toString()
13+
if !key.isEmpty {
14+
await ArgmaxSDK.with(ArgmaxConfig(apiKey: key))
15+
}
16+
semaphore.signal()
17+
}
18+
19+
semaphore.wait()
20+
}
21+
22+
public class VadResultArray: NSObject {
23+
var data: SRArray<Bool>
24+
25+
init(_ data: [Bool]) {
26+
self.data = SRArray(data)
27+
}
28+
}
29+
30+
@_cdecl("am2_vad_init")
31+
public func am2_vad_init() -> Bool {
32+
let semaphore = DispatchSemaphore(value: 0)
33+
var success = false
34+
35+
Task {
36+
do {
37+
vadInstance = try await VoiceActivityDetector.modelVAD()
38+
success = true
39+
} catch {
40+
success = false
41+
}
42+
semaphore.signal()
43+
}
44+
45+
semaphore.wait()
46+
return success
1147
}
1248

13-
@_cdecl("check_am2_ready")
14-
public func check_am2_ready() -> Bool {
15-
return isAM2Ready
49+
@_cdecl("am2_vad_detect")
50+
public func am2_vad_detect(
51+
samplesPtr: UnsafePointer<Float>,
52+
samplesLen: Int
53+
) -> VadResultArray {
54+
guard let vad = vadInstance else {
55+
return VadResultArray([])
56+
}
57+
58+
let audioArray = Array(UnsafeBufferPointer(start: samplesPtr, count: samplesLen))
59+
let voiceSegments = vad.voiceActivity(in: audioArray)
60+
61+
return VadResultArray(voiceSegments)
1662
}
1763

64+
@_cdecl("am2_vad_index_to_seconds")
65+
public func am2_vad_index_to_seconds(index: Int) -> Float {
66+
guard let vad = vadInstance else {
67+
return 0.0
68+
}
69+
return vad.voiceActivityIndexToSeconds(index)
70+
}

crates/vvad/Cargo.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
[package]
2+
name = "vvad"
3+
version = "0.1.0"
4+
edition = "2021"
5+
6+
[target.'cfg(all(target_os = "macos", target_arch = "aarch64"))'.dependencies]
7+
hypr-am2 = { workspace = true }
8+
9+
[target.'cfg(not(all(target_os = "macos", target_arch = "aarch64")))'.dependencies]
10+
earshot = { workspace = true }

0 commit comments

Comments
 (0)