Skip to content

Commit a63e082

Browse files
feat(am2): add file transcription support using WhisperKitPro (#2158)
* feat(am2): add file transcription support using WhisperKitPro

  - Add Swift bindings for WhisperKitPro file transcription
  - Add am2_transcribe_init, am2_transcribe_is_ready, am2_transcribe_file functions
  - Add Rust wrapper module for transcription functionality
  - Add test for transcription initialization

  Co-Authored-By: yujonglee <[email protected]>

* refactor(am2): make transcribe::is_ready() consistent with vad::is_ready()

  - Remove unnecessary FFI call in is_ready() to match vad module pattern
  - Remove unused am2_transcribe_is_ready FFI declaration

  Addresses code review feedback about design inconsistency.

  Co-Authored-By: yujonglee <[email protected]>

---------

Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: yujonglee <[email protected]>
1 parent ed76667 commit a63e082

File tree

2 files changed

+114
-0
lines changed

2 files changed

+114
-0
lines changed

crates/am2/src/lib.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
// FFI bindings into the Swift half of the am2 SDK (generated by swift-rs's
// `swift!` macro; the Swift implementations live in swift-lib/src/lib.swift).
swift!(fn initialize_am2_sdk(api_key: &SRString));
swift!(fn am2_vad_init() -> bool);
swift!(fn am2_vad_detect(samples_ptr: *const f32, samples_len: i64) -> SRObject<VadResultArray>);
swift!(fn am2_vad_index_to_seconds(index: i64) -> f32);
// Transcription entry points; the return layout of `am2_transcribe_file`
// must match `TranscribeResultFFI` below.
swift!(fn am2_transcribe_init(model: &SRString) -> bool);
swift!(fn am2_transcribe_file(audio_path: &SRString) -> SRObject<TranscribeResultFFI>);

// Guards one-time SDK initialization (see `init`); holds `()` because only
// the "ran once" fact matters.
static SDK_INITIALIZED: OnceLock<()> = OnceLock::new();
1113

@@ -14,6 +16,12 @@ pub struct VadResultArray {
1416
pub data: SRArray<bool>,
1517
}
1618

// C-layout mirror of the Swift `TranscribeResult` class returned by
// `am2_transcribe_file`. Field order and types must stay in sync with the
// Swift side — do not reorder.
#[repr(C)]
pub struct TranscribeResultFFI {
    // Transcribed text; empty when `success` is false.
    pub text: SRString,
    // True when transcription completed without throwing on the Swift side.
    pub success: bool,
}
1725
pub fn init() {
1826
SDK_INITIALIZED.get_or_init(|| {
1927
let api_key = std::env::var("AM_API_KEY").unwrap_or_default();
@@ -88,6 +96,40 @@ pub mod vad {
8896
}
8997
}
9098

99+
pub mod transcribe {
100+
use std::sync::OnceLock;
101+
102+
use super::*;
103+
104+
static TRANSCRIBE_INITIALIZED: OnceLock<bool> = OnceLock::new();
105+
106+
pub fn init(model: &str) -> bool {
107+
*TRANSCRIBE_INITIALIZED.get_or_init(|| {
108+
let model = SRString::from(model);
109+
unsafe { am2_transcribe_init(&model) }
110+
})
111+
}
112+
113+
pub fn is_ready() -> bool {
114+
TRANSCRIBE_INITIALIZED.get().copied().unwrap_or(false)
115+
}
116+
117+
#[derive(Debug, Clone)]
118+
pub struct TranscribeResult {
119+
pub text: String,
120+
pub success: bool,
121+
}
122+
123+
pub fn transcribe_file(audio_path: &str) -> TranscribeResult {
124+
let audio_path = SRString::from(audio_path);
125+
let result = unsafe { am2_transcribe_file(&audio_path) };
126+
TranscribeResult {
127+
text: result.text.to_string(),
128+
success: result.success,
129+
}
130+
}
131+
}
132+
91133
#[cfg(test)]
92134
mod tests {
93135
use super::*;
@@ -104,4 +146,11 @@ mod tests {
104146
assert!(vad::init());
105147
assert!(vad::is_ready());
106148
}
149+
150+
#[test]
fn test_am2_transcribe_init() {
    // The SDK must be initialized before any model can be loaded.
    init();
    let initialized = transcribe::init("large-v3-v20240930_626MB");
    assert!(initialized, "model load should succeed");
    assert!(
        transcribe::is_ready(),
        "ready flag should be set after a successful init"
    );
}
107156
}

crates/am2/swift-lib/src/lib.swift

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import Foundation
33
import SwiftRs
44

55
// Lazily-created singletons backing the C-ABI entry points below.
// NOTE(review): access is not synchronized — presumably the Rust wrapper
// serializes init against use; confirm against crates/am2/src/lib.rs.
private var vadInstance: VoiceActivityDetector?
private var whisperKitProInstance: WhisperKitPro?
67

78
@_cdecl("initialize_am2_sdk")
89
public func initialize_am2_sdk(apiKey: SRString) {
@@ -68,3 +69,67 @@ public func am2_vad_index_to_seconds(index: Int) -> Float {
6869
}
6970
return vad.voiceActivityIndexToSeconds(index)
7071
}
72+
73+
/// FFI return value for `am2_transcribe_file`.
/// Field order and types must match the `#[repr(C)] TranscribeResultFFI`
/// struct on the Rust side — do not reorder.
public class TranscribeResult: NSObject {
    // Transcribed text; empty on failure.
    var text: SRString
    // True when transcription completed without throwing.
    var success: Bool

    init(text: String, success: Bool) {
        self.text = SRString(text)
        self.success = success
    }
}
82+
83+
/// Synchronously loads a WhisperKitPro model, blocking the calling thread
/// until the async load completes. Returns `true` on success.
@_cdecl("am2_transcribe_init")
public func am2_transcribe_init(model: SRString) -> Bool {
    let semaphore = DispatchSemaphore(value: 0)
    var success = false
    // Copy the string out of the FFI wrapper before hopping threads.
    let modelName = model.toString()

    // BUGFIX: a plain `Task { }` inherits the caller's actor context, so
    // invoking this from the main actor would have the task queued on the
    // very executor that `semaphore.wait()` below is blocking — a permanent
    // deadlock. `Task.detached` runs the load off the caller's context.
    // The semaphore orders the write to `success` before the read below.
    Task.detached {
        do {
            let config = WhisperKitProConfig(model: modelName)
            whisperKitProInstance = try await WhisperKitPro(config)
            success = true
        } catch {
            success = false
        }
        semaphore.signal()
    }

    semaphore.wait()
    return success
}
103+
104+
/// True once `am2_transcribe_init` has successfully created a model instance.
/// NOTE(review): per the commit message the Rust wrapper no longer declares
/// this symbol (it tracks readiness on its own side); kept for any remaining
/// external callers of the C ABI.
@_cdecl("am2_transcribe_is_ready")
public func am2_transcribe_is_ready() -> Bool {
    return whisperKitProInstance != nil
}
108+
109+
/// Synchronously transcribes the audio file at `audioPath`, blocking the
/// calling thread until the async transcription completes.
/// Returns an unsuccessful, empty result when no model has been loaded or
/// when transcription throws.
@_cdecl("am2_transcribe_file")
public func am2_transcribe_file(audioPath: SRString) -> TranscribeResult {
    let semaphore = DispatchSemaphore(value: 0)
    var resultText = ""
    var success = false
    // Copy the string out of the FFI wrapper before hopping threads.
    let path = audioPath.toString()

    // BUGFIX: a plain `Task { }` inherits the caller's actor context, so
    // invoking this from the main actor would deadlock on `semaphore.wait()`
    // below. `Task.detached` runs the work off the caller's context.
    Task.detached {
        // `defer` guarantees the waiter is released on every exit path
        // (the original duplicated `semaphore.signal()` in two places).
        defer { semaphore.signal() }

        guard let whisperKit = whisperKitProInstance else {
            return  // not initialized: leave (resultText, success) as failure
        }

        do {
            let results = try await whisperKit.transcribe(audioPath: path)
            resultText = WhisperKitProUtils.mergeTranscriptionResults(results).text
            success = true
        } catch {
            resultText = ""
            success = false
        }
    }

    semaphore.wait()
    return TranscribeResult(text: resultText, success: success)
}

0 commit comments

Comments
 (0)