From f30f5c00a3a2c26f1be31173c8cac7c97f473e79 Mon Sep 17 00:00:00 2001 From: Max Date: Wed, 22 Apr 2026 21:19:55 +0200 Subject: [PATCH 1/4] enable NNAPI on Android --- src-tauri/Cargo.toml | 1 + src-tauri/src/ai_processing.rs | 31 ++++++++++++++++++++++++------- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 2072d198f..c1f7446a3 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -80,6 +80,7 @@ ndk-context = "0.1" jni = "0.21" reqwest = { version = "0.12", default-features = false, features = ["json", "multipart", "rustls-tls-webpki-roots"] } include_dir = "0.7.3" +ort = { version = "2.0.0-rc.10", features = ["nnapi"] } [target.'cfg(target_os = "linux")'.dependencies] gtk = "0.18.2" diff --git a/src-tauri/src/ai_processing.rs b/src-tauri/src/ai_processing.rs index e6f0cae5b..f965a1272 100644 --- a/src-tauri/src/ai_processing.rs +++ b/src-tauri/src/ai_processing.rs @@ -10,6 +10,9 @@ use image::{ }; use ndarray::{Array, Array4, IxDyn}; use ort::session::Session; +use ort::session::builder::SessionBuilder; +#[cfg(target_os = "android")] +use ort::execution_providers::NNAPIExecutionProvider; use ort::value::Tensor; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; @@ -94,6 +97,20 @@ pub struct AiState { pub depth_map: Option, } +fn add_platform_optimization(builder: SessionBuilder) -> Result { + #[cfg(target_os = "android")] + { + return Ok(builder.with_execution_providers([ + NNAPIExecutionProvider::default().build() + ])?); + } + + #[cfg(not(target_os = "android"))] + { + Ok(builder) + } +} + fn edt_1d(f: &mut [f32], v: &mut [usize], z: &mut [f32], d: &mut [f32]) { let n = f.len(); if n == 0 { @@ -301,11 +318,11 @@ pub async fn get_or_init_ai_models( let sky_seg_path = models_dir.join(SKYSEG_FILENAME); let depth_path = models_dir.join(DEPTH_FILENAME); - let sam_encoder = Session::builder()?.commit_from_file(encoder_path)?; - let sam_decoder = Session::builder()?.commit_from_file(decoder_path)?; - let u2netp = Session::builder()?.commit_from_file(u2netp_path)?; - let sky_seg = Session::builder()?.commit_from_file(sky_seg_path)?; - let depth_anything = Session::builder()?.commit_from_file(depth_path)?; + let sam_encoder = add_platform_optimization(Session::builder()?)?.commit_from_file(encoder_path)?; + let sam_decoder = add_platform_optimization(Session::builder()?)?.commit_from_file(decoder_path)?; + let u2netp = add_platform_optimization(Session::builder()?)?.commit_from_file(u2netp_path)?; + let sky_seg = add_platform_optimization(Session::builder()?)?.commit_from_file(sky_seg_path)?; + let depth_anything = add_platform_optimization(Session::builder()?)?.commit_from_file(depth_path)?; crate::register_exit_handler(); @@ -372,7 +389,7 @@ pub async fn get_or_init_denoise_model( let _ = ort::init().with_name("AI-Denoise").commit(); let model_path = models_dir.join(DENOISE_FILENAME); - let session = Session::builder()?.commit_from_file(model_path)?; + let session = add_platform_optimization(Session::builder()?)?.commit_from_file(model_path)?; let denoise_model = Arc::new(Mutex::new(session)); crate::register_exit_handler(); @@ -503,7 +520,7 @@ pub async fn get_or_init_lama_model( let _ = ort::init().with_name("AI-Inpainting").commit(); let model_path = models_dir.join(LAMA_FILENAME); - let session = Session::builder()?.commit_from_file(model_path)?; + let session = add_platform_optimization(Session::builder()?)?.commit_from_file(model_path)?; let lama_model = Arc::new(Mutex::new(session)); crate::register_exit_handler(); From eae29c53fa5824e17dc8d5c7a5f20119683496a9 Mon Sep 17 00:00:00 2001 From: Max Date: Thu, 23 Apr 2026 17:47:34 +0200 Subject: [PATCH 2/4] use MobileSAM instead of SAM2 on Android --- src-tauri/src/ai_processing.rs | 109 ++++++++++++++++++++++++--------- 1 file changed, 81 insertions(+), 28 deletions(-) diff --git a/src-tauri/src/ai_processing.rs b/src-tauri/src/ai_processing.rs index f965a1272..9da0dfc21 100644 --- a/src-tauri/src/ai_processing.rs +++ b/src-tauri/src/ai_processing.rs @@ -21,13 +21,28 @@ use tauri::Manager; use tokenizers::Tokenizer; use tokio::sync::Mutex as TokioMutex; -const ENCODER_URL: &str = "https://huggingface.co/CyberTimon/RapidRAW-Models/resolve/main/sam_vit_b_01ec64_encoder.onnx?download=true"; -const DECODER_URL: &str = "https://huggingface.co/CyberTimon/RapidRAW-Models/resolve/main/sam_vit_b_01ec64_decoder.onnx?download=true"; -const ENCODER_FILENAME: &str = "sam_vit_b_01ec64_encoder.onnx"; -const DECODER_FILENAME: &str = "sam_vit_b_01ec64_decoder.onnx"; +#[cfg(target_os = "android")] +mod sam_cfg { + pub const ENCODER_URL: &str = "https://huggingface.co/Acly/MobileSAM/resolve/main/mobile_sam_image_encoder.onnx?download=true"; + pub const DECODER_URL: &str = "https://huggingface.co/Acly/MobileSAM/resolve/main/sam_mask_decoder_single.onnx?download=true"; + pub const ENCODER_FILENAME: &str = "mobile_sam_image_encoder.onnx"; + pub const DECODER_FILENAME: &str = "sam_mask_decoder_single.onnx"; + pub const ENCODER_SHA256: &str = "580f5fb648ea1062c0aabc26217aed56921985f03f0cbbd852bba81d760cc749"; + pub const DECODER_SHA256: &str = "93915fc7c993ab9d59ab8c9ccd3bce37f7509c81ab4150a74abd4d2abbd8570d"; +} + +#[cfg(not(target_os = "android"))] +mod sam_cfg { + pub const ENCODER_URL: &str = "https://huggingface.co/CyberTimon/RapidRAW-Models/resolve/main/sam_vit_b_01ec64_encoder.onnx?download=true"; + pub const DECODER_URL: &str = "https://huggingface.co/CyberTimon/RapidRAW-Models/resolve/main/sam_vit_b_01ec64_decoder.onnx?download=true"; + pub const ENCODER_FILENAME: &str = "sam_vit_b_01ec64_encoder.onnx"; + pub const DECODER_FILENAME: &str = "sam_vit_b_01ec64_decoder.onnx"; + pub const ENCODER_SHA256: &str = "16ab73d9c824886f0de2938c19df22fb9ec3deebfd0de58e65177e479213d7d1"; + pub const DECODER_SHA256: &str = "85d0d672cf5b7fe763edcde429e5533e62f674af4b15c7d688b7673b0ef00bf7"; +} + +use sam_cfg::*; const SAM_INPUT_SIZE: u32 = 1024; -const ENCODER_SHA256: &str = "16ab73d9c824886f0de2938c19df22fb9ec3deebfd0de58e65177e479213d7d1"; -const DECODER_SHA256: &str = "85d0d672cf5b7fe763edcde429e5533e62f674af4b15c7d688b7673b0ef00bf7"; const U2NETP_URL: &str = "https://huggingface.co/CyberTimon/RapidRAW-Models/resolve/main/u2net.onnx?download=true"; @@ -941,32 +956,70 @@ pub fn generate_image_embeddings( let (actual_width, actual_height) = rgb_image.dimensions(); let raw_pixels = rgb_image.as_raw(); - let mut input_tensor: Array = - Array::zeros((1, 3, SAM_INPUT_SIZE as usize, SAM_INPUT_SIZE as usize)); + #[cfg(target_os = "android")] + { + let mut input_tensor = Array::::zeros(( + SAM_INPUT_SIZE as usize, + SAM_INPUT_SIZE as usize, + 3 + )); - let w_usize = actual_width as usize; - for y in 0..(actual_height as usize) { - for x in 0..w_usize { - let idx = (y * w_usize + x) * 3; - input_tensor[[0, 0, y, x]] = raw_pixels[idx]; - input_tensor[[0, 1, y, x]] = raw_pixels[idx + 1]; - input_tensor[[0, 2, y, x]] = raw_pixels[idx + 2]; + let w_usize = actual_width as usize; + for y in 0..actual_height { + let y_idx = y as usize; + let y_offset = y_idx * w_usize; + for x in 0..actual_width { + let x_idx = x as usize; + let idx = (y_offset + x_idx) * 3; + + input_tensor[[y_idx, x_idx, 0]] = raw_pixels[idx] as f32; + input_tensor[[y_idx, x_idx, 1]] = raw_pixels[idx + 1] as f32; + input_tensor[[y_idx, x_idx, 2]] = raw_pixels[idx + 2] as f32; + } } + + let input_tensor_dyn = input_tensor.into_dyn().as_standard_layout().into_owned(); + let input_tensor_ort = Tensor::from_array(input_tensor_dyn)?; + let mut session = encoder.lock().unwrap(); + + let outputs = session.run(ort::inputs!["input_image" => input_tensor_ort])?; + let embeddings = outputs[0].try_extract_array::()?.to_owned(); + + Ok(ImageEmbeddings { + path_hash: "".to_string(), + embeddings: embeddings.into_dyn(), + original_size: (orig_width, orig_height), + }) } + #[cfg(not(target_os = "android"))] + { + let mut input_tensor: Array = + Array::zeros((1, 3, SAM_INPUT_SIZE as usize, SAM_INPUT_SIZE as usize)); + + let w_usize = actual_width as usize; + for y in 0..(actual_height as usize) { + for x in 0..w_usize { + let idx = (y * w_usize + x) * 3; + input_tensor[[0, 0, y, x]] = raw_pixels[idx]; + input_tensor[[0, 1, y, x]] = raw_pixels[idx + 1]; + input_tensor[[0, 2, y, x]] = raw_pixels[idx + 2]; + } + } - let input_tensor_dyn = input_tensor.into_dyn(); - let input_values = input_tensor_dyn.as_standard_layout(); - let input_tensor_ort = Tensor::from_array(input_values.into_owned())?; - let mut session = encoder.lock().unwrap(); - let outputs = session.run(ort::inputs![input_tensor_ort])?; - - let embeddings = outputs[0].try_extract_array::()?.to_owned(); - - Ok(ImageEmbeddings { - path_hash: "".to_string(), - embeddings: embeddings.into_dyn(), - original_size: (orig_width, orig_height), - }) + let input_tensor_dyn = input_tensor.into_dyn(); + let input_values = input_tensor_dyn.as_standard_layout(); + let input_tensor_ort = Tensor::from_array(input_values.into_owned())?; + let mut session = encoder.lock().unwrap(); + let outputs = session.run(ort::inputs![input_tensor_ort])?; + + let embeddings = outputs[0].try_extract_array::()?.to_owned(); + + Ok(ImageEmbeddings { + path_hash: "".to_string(), + embeddings: embeddings.into_dyn(), + original_size: (orig_width, orig_height), + }) + } } pub fn run_sam_decoder( From 807f74f44cfd57d40474a6293e1a3656269ae704 Mon Sep 17 00:00:00 2001 From: Max Date: Fri, 24 Apr 2026 15:47:17 +0200 Subject: [PATCH 3/4] unload unused AI models quickly on android to reduce memory pressure --- src-tauri/src/ai_processing.rs | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/src-tauri/src/ai_processing.rs b/src-tauri/src/ai_processing.rs index 9da0dfc21..56686c66f 100644 --- a/src-tauri/src/ai_processing.rs +++ b/src-tauri/src/ai_processing.rs @@ -112,6 +112,17 @@ pub struct AiState { pub depth_map: Option, } +pub fn clear_all_models(ai_state_lock: &mut Option) { + if let Some(state) = ai_state_lock.as_mut() { + state.models = None; + state.denoise_model = None; + state.clip_models = None; + state.lama_model = None; + state.embeddings = None; + state.depth_map = None; + } +} + fn add_platform_optimization(builder: SessionBuilder) -> Result { #[cfg(target_os = "android")] { @@ -267,6 +278,13 @@ pub async fn get_or_init_ai_models( } let _guard = ai_init_lock.lock().await; + #[cfg(target_os = "android")] + { + let mut ai_state_lock = ai_state_mutex.lock().unwrap(); + if ai_state_lock.as_ref().and_then(|s| s.models.as_ref()).is_none() { + clear_all_models(&mut ai_state_lock); + } + } if let Some(models) = ai_state_mutex .lock() @@ -381,6 +399,13 @@ pub async fn get_or_init_denoise_model( } let _guard = ai_init_lock.lock().await; + #[cfg(target_os = "android")] + { + let mut ai_state_lock = ai_state_mutex.lock().unwrap(); + if ai_state_lock.as_ref().and_then(|s| s.denoise_model.as_ref()).is_none() { + clear_all_models(&mut ai_state_lock); + } + } if let Some(denoise_model) = ai_state_mutex .lock() @@ -512,6 +537,13 @@ pub async fn get_or_init_lama_model( } let _guard = ai_init_lock.lock().await; + #[cfg(target_os = "android")] + { + let mut ai_state_lock = ai_state_mutex.lock().unwrap(); + if ai_state_lock.as_ref().and_then(|s| s.lama_model.as_ref()).is_none() { + clear_all_models(&mut ai_state_lock); + } + } if let Some(lama_model) = ai_state_mutex .lock() From 8cf26796cf8ea24bb8e95c85e40cd710d2ab1ea5 Mon Sep 17 00:00:00 2001 From: Max Date: Fri, 24 Apr 2026 23:19:57 +0200 Subject: [PATCH 4/4] fix inpainting on Android --- src-tauri/src/ai_processing.rs | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/src-tauri/src/ai_processing.rs b/src-tauri/src/ai_processing.rs index 56686c66f..88feaaab1 100644 --- a/src-tauri/src/ai_processing.rs +++ b/src-tauri/src/ai_processing.rs @@ -29,6 +29,10 @@ mod sam_cfg { pub const DECODER_FILENAME: &str = "sam_mask_decoder_single.onnx"; pub const ENCODER_SHA256: &str = "580f5fb648ea1062c0aabc26217aed56921985f03f0cbbd852bba81d760cc749"; pub const DECODER_SHA256: &str = "93915fc7c993ab9d59ab8c9ccd3bce37f7509c81ab4150a74abd4d2abbd8570d"; + + pub const LAMA_URL: &str = "https://huggingface.co/qualcomm/LaMa-Dilated/resolve/ab898502c9bd764a50eb2719a309694b43eae658/LaMa-Dilated.onnx?download=true"; + pub const LAMA_FILENAME: &str = "LaMa-Dilated.onnx"; + pub const LAMA_SHA256: &str = "6f9e1d401eb67a63fb1be6c0cf3283d800bf4c20656028f96b044fedc382d762"; } #[cfg(not(target_os = "android"))] @@ -39,6 +43,10 @@ mod sam_cfg { pub const DECODER_FILENAME: &str = "sam_vit_b_01ec64_decoder.onnx"; pub const ENCODER_SHA256: &str = "16ab73d9c824886f0de2938c19df22fb9ec3deebfd0de58e65177e479213d7d1"; pub const DECODER_SHA256: &str = "85d0d672cf5b7fe763edcde429e5533e62f674af4b15c7d688b7673b0ef00bf7"; + + pub const LAMA_URL: &str = "https://huggingface.co/CyberTimon/RapidRAW-Models/resolve/main/lama_fp16.onnx?download=true"; + pub const LAMA_FILENAME: &str = "lama_fp16.onnx"; + pub const LAMA_SHA256: &str = "2d6be6277c400d6f1b91819737f7c3da935e5c63d1b521d393be1196a2bfa82c"; } use sam_cfg::*; @@ -66,11 +74,6 @@ const DENOISE_URL: &str = "https://huggingface.co/CyberTimon/RapidRAW-Models/res const DENOISE_FILENAME: &str = "nind_denoise_utnet_684.onnx"; const DENOISE_SHA256: &str = "ee3586279d514df557ff3f7dec6df37fafc51ba5d3a3435b2cc9ac2d9017e7fe"; -const LAMA_URL: &str = - "https://huggingface.co/CyberTimon/RapidRAW-Models/resolve/main/lama_fp16.onnx?download=true"; -const LAMA_FILENAME: &str = "lama_fp16.onnx"; -const LAMA_SHA256: &str = "2d6be6277c400d6f1b91819737f7c3da935e5c63d1b521d393be1196a2bfa82c"; - const DEPTH_URL: &str = "https://huggingface.co/CyberTimon/RapidRAW-Models/resolve/main/depth_anything_v2_vits.onnx?download=true"; const DEPTH_FILENAME: &str = "depth_anything_v2_vits.onnx"; const DEPTH_INPUT_SIZE: u32 = 518; @@ -567,7 +570,7 @@ pub async fn get_or_init_lama_model( let _ = ort::init().with_name("AI-Inpainting").commit(); let model_path = models_dir.join(LAMA_FILENAME); - let session = add_platform_optimization(Session::builder()?)?.commit_from_file(model_path)?; + let session = Session::builder()?.commit_from_file(model_path)?; let lama_model = Arc::new(Mutex::new(session)); crate::register_exit_handler(); @@ -879,6 +882,9 @@ pub fn run_lama_inpainting( let cropped_img = imageops::crop_imm(&rgba, x0, y0, crop_w, crop_h).to_image(); let cropped_mask = imageops::crop_imm(mask, x0, y0, crop_w, crop_h).to_image(); + #[cfg(target_os = "android")] + let max_dim_limit: u32 = 512; + #[cfg(not(target_os = "android"))] let max_dim_limit: u32 = 768; let needs_downscale = crop_w > max_dim_limit || crop_h > max_dim_limit; @@ -935,9 +941,14 @@ pub fn run_lama_inpainting( let mut result_inf = RgbaImage::new(fw, fh); for y in 0..fh { for x in 0..fw { - let r = output_tensor[[0, 0, y as usize, x as usize]].clamp(0.0, 255.0) as u8; - let g = output_tensor[[0, 1, y as usize, x as usize]].clamp(0.0, 255.0) as u8; - let b = output_tensor[[0, 2, y as usize, x as usize]].clamp(0.0, 255.0) as u8; + #[cfg(target_os = "android")] + let multiplier = 255.0; + #[cfg(not(target_os = "android"))] + let multiplier = 1.0; + + let r = (output_tensor[[0, 0, y as usize, x as usize]] * multiplier).clamp(0.0, 255.0) as u8; + let g = (output_tensor[[0, 1, y as usize, x as usize]] * multiplier).clamp(0.0, 255.0) as u8; + let b = (output_tensor[[0, 2, y as usize, x as usize]] * multiplier).clamp(0.0, 255.0) as u8; result_inf.put_pixel(x, y, Rgba([r, g, b, 255])); } }