Skip to content

Commit 8fdc828

Browse files
committed
update am
1 parent 9c12f08 commit 8fdc828

File tree

4 files changed

+95
-7
lines changed

4 files changed

+95
-7
lines changed

apps/web/content/docs/developers/16.local-models.mdx

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,19 @@ description: "Learn about local models in Hyprnote"
1313

1414
- All STT models are stored in `~/Library/Application Support/hyprnote/models/stt/`
1515

16+
### Manual Download
17+
1618
| URL | Folder |
1719
| --- | --- |
1820
| https://huggingface.co/argmaxinc/whisperkit-pro/tree/main/openai_whisper-large-v3-v20240930_626MB | `openai_whisper-large-v3-v20240930_626MB` |
1921
| https://huggingface.co/argmaxinc/parakeetkit-pro/tree/main/nvidia_parakeet-v2_476MB | `nvidia_parakeet-v2_476MB` |
2022
| https://huggingface.co/argmaxinc/parakeetkit-pro/tree/main/nvidia_parakeet-v3_494MB | `nvidia_parakeet-v3_494MB` |
23+
| https://huggingface.co/argmaxinc/ctckit-pro/tree/main/parakeet-tdt_ctc-110m | `parakeet-tdt_ctc-110m` |
24+
25+
If you have [Huggingface CLI](https://huggingface.co/docs/huggingface_hub/guides/cli) installed:
26+
27+
```bash
28+
hf download argmaxinc/ctckit-pro \
29+
--include "parakeet-tdt_ctc-110m/*" \
30+
--local-dir "$HOME/Library/Application Support/hyprnote/models/stt"
31+
```

crates/am/src/client.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,40 @@ impl Client {
3333
}
3434
}
3535

36+
pub async fn wait_for_ready(
37+
&self,
38+
max_wait_time: Option<u32>,
39+
poll_interval: Option<u32>,
40+
) -> Result<ServerStatus, Error> {
41+
let url = format!("{}/waitForReady", self.base_url);
42+
let mut request = self.client.get(&url);
43+
44+
if let Some(max_wait) = max_wait_time {
45+
request = request.query(&[("maxWaitTime", max_wait)]);
46+
}
47+
if let Some(poll) = poll_interval {
48+
request = request.query(&[("pollInterval", poll)]);
49+
}
50+
51+
let response = request.send().await?;
52+
53+
match response.status() {
54+
StatusCode::OK => Ok(response.json().await?),
55+
StatusCode::BAD_REQUEST | StatusCode::REQUEST_TIMEOUT => {
56+
Err(self.handle_error_response(response).await)
57+
}
58+
_ => Err(Error::UnexpectedResponse),
59+
}
60+
}
61+
3662
pub async fn init(&self, request: InitRequest) -> Result<InitResponse, Error> {
3763
if !request.api_key.starts_with("ax_") {
3864
return Err(Error::InvalidApiKey);
3965
}
4066

4167
let url = format!("{}/init", self.base_url);
4268
let response = self.client.post(&url).json(&request).send().await?;
69+
println!("{:?}", request);
4370

4471
match response.status() {
4572
StatusCode::OK => Ok(response.json().await?),
@@ -100,9 +127,18 @@ impl InitRequest {
100127
Self {
101128
api_key: api_key.into(),
102129
model: None,
130+
model_token: None,
131+
download_base: None,
103132
model_repo: None,
104133
model_folder: None,
134+
tokenizer_folder: None,
135+
fast_load: None,
136+
fast_load_encoder_compute_units: None,
137+
fast_load_decoder_compute_units: None,
138+
model_vad: None,
139+
verbose: None,
105140
custom_vocabulary: None,
141+
custom_vocabulary_model_folder: None,
106142
}
107143
}
108144

@@ -124,6 +160,13 @@ impl InitRequest {
124160
match model {
125161
crate::AmModel::ParakeetV2 => {
126162
self.custom_vocabulary = Some(vec![]);
163+
self.custom_vocabulary_model_folder = Some(
164+
base_dir
165+
.as_ref()
166+
.join("parakeet-tdt_ctc-110m")
167+
.to_string_lossy()
168+
.to_string(),
169+
);
127170
}
128171
_ => {
129172
self.custom_vocabulary = None;

crates/am/src/lib.rs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,7 @@ mod tests {
1818
let status = client.status().await;
1919
println!("{:?}", status);
2020
client
21-
.init(InitRequest {
22-
api_key: "".to_string(),
23-
model: Some("nvidia_parakeet-v2_476MB".to_string()),
24-
model_repo: Some("argmaxinc/parakeetkit-pro".to_string()),
25-
model_folder: None,
26-
custom_vocabulary: Some(vec![]),
27-
})
21+
.init(InitRequest::new("").with_model(AmModel::ParakeetV2, "/tmp"))
2822
.await
2923
.unwrap();
3024
assert!(true);

crates/am/src/types.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,20 @@ common_derives! {
1414
pub model_state: ModelState,
1515
pub verbose: bool,
1616
#[serde(skip_serializing_if = "Option::is_none")]
17+
pub download_progress: Option<DownloadProgress>,
18+
#[serde(skip_serializing_if = "Option::is_none")]
1719
pub message: Option<String>,
1820
}
1921
}
2022

23+
common_derives! {
24+
#[serde(rename_all = "camelCase")]
25+
pub struct DownloadProgress {
26+
pub progress_percentage: f64,
27+
pub is_downloading: bool,
28+
}
29+
}
30+
2131
common_derives! {
2232
#[derive(Eq, PartialEq)]
2333
#[serde(rename_all = "lowercase")]
@@ -51,11 +61,41 @@ common_derives! {
5161
#[serde(skip_serializing_if = "Option::is_none")]
5262
pub model: Option<String>,
5363
#[serde(skip_serializing_if = "Option::is_none")]
64+
pub model_token: Option<String>,
65+
#[serde(skip_serializing_if = "Option::is_none")]
66+
pub download_base: Option<String>,
67+
#[serde(skip_serializing_if = "Option::is_none")]
5468
pub model_repo: Option<String>,
5569
#[serde(skip_serializing_if = "Option::is_none")]
5670
pub model_folder: Option<String>,
5771
#[serde(skip_serializing_if = "Option::is_none")]
72+
pub tokenizer_folder: Option<String>,
73+
#[serde(skip_serializing_if = "Option::is_none")]
74+
pub fast_load: Option<bool>,
75+
#[serde(skip_serializing_if = "Option::is_none")]
76+
pub fast_load_encoder_compute_units: Option<ComputeUnits>,
77+
#[serde(skip_serializing_if = "Option::is_none")]
78+
pub fast_load_decoder_compute_units: Option<ComputeUnits>,
79+
#[serde(skip_serializing_if = "Option::is_none")]
80+
pub model_vad: Option<bool>,
81+
#[serde(skip_serializing_if = "Option::is_none")]
82+
pub verbose: Option<bool>,
83+
#[serde(skip_serializing_if = "Option::is_none")]
5884
pub custom_vocabulary: Option<Vec<String>>,
85+
#[serde(skip_serializing_if = "Option::is_none")]
86+
pub custom_vocabulary_model_folder: Option<String>,
87+
}
88+
}
89+
90+
common_derives! {
91+
#[serde(rename_all = "lowercase")]
92+
pub enum ComputeUnits {
93+
Cpu,
94+
#[serde(rename = "cpuandgpu")]
95+
CpuAndGpu,
96+
#[serde(rename = "cpuandneuralengine")]
97+
CpuAndNeuralEngine,
98+
All,
5999
}
60100
}
61101

0 commit comments

Comments
 (0)