Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src-tauri/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ ndk-context = "0.1"
jni = "0.21"
reqwest = { version = "0.12", default-features = false, features = ["json", "multipart", "rustls-tls-webpki-roots"] }
include_dir = "0.7.3"
ort = { version = "2.0.0-rc.10", features = ["nnapi"] }

[target.'cfg(target_os = "linux")'.dependencies]
gtk = "0.18.2"
Expand Down
197 changes: 155 additions & 42 deletions src-tauri/src/ai_processing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@
use anyhow::Result;
use image::imageops::{self, FilterType};
use image::{
DynamicImage, GenericImageView, GrayImage, ImageBuffer, Luma, Rgb, Rgb32FImage, Rgba, RgbaImage,

Check warning on line 9 in src-tauri/src/ai_processing.rs

View workflow job for this annotation

GitHub Actions / cargo fmt

Diff in /home/runner/work/RapidRAW/RapidRAW/src-tauri/src/ai_processing.rs

Check warning on line 9 in src-tauri/src/ai_processing.rs

View workflow job for this annotation

GitHub Actions / cargo fmt

Diff in /home/runner/work/RapidRAW/RapidRAW/src-tauri/src/ai_processing.rs
};
use ndarray::{Array, Array4, IxDyn};
use ort::session::Session;
use ort::session::builder::SessionBuilder;
#[cfg(target_os = "android")]
use ort::execution_providers::NNAPIExecutionProvider;
use ort::value::Tensor;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
Expand All @@ -18,13 +21,36 @@
use tokenizers::Tokenizer;
use tokio::sync::Mutex as TokioMutex;

const ENCODER_URL: &str = "https://huggingface.co/CyberTimon/RapidRAW-Models/resolve/main/sam_vit_b_01ec64_encoder.onnx?download=true";
const DECODER_URL: &str = "https://huggingface.co/CyberTimon/RapidRAW-Models/resolve/main/sam_vit_b_01ec64_decoder.onnx?download=true";
const ENCODER_FILENAME: &str = "sam_vit_b_01ec64_encoder.onnx";
const DECODER_FILENAME: &str = "sam_vit_b_01ec64_decoder.onnx";
/// Download URLs, on-disk filenames and SHA-256 hex strings for the models selected
/// when building for Android.
///
/// The constant names mirror the non-Android `sam_cfg` module so the rest of the file
/// can refer to them without caring which target is being built.
#[cfg(target_os = "android")]
mod sam_cfg {
    // --- SAM image encoder (MobileSAM repository) ---

    /// Download URL of the encoder model.
    pub const ENCODER_URL: &str =
        "https://huggingface.co/Acly/MobileSAM/resolve/main/mobile_sam_image_encoder.onnx?download=true";
    /// Filename of the encoder model.
    pub const ENCODER_FILENAME: &str = "mobile_sam_image_encoder.onnx";
    /// SHA-256 hex string associated with the encoder model file.
    pub const ENCODER_SHA256: &str =
        "580f5fb648ea1062c0aabc26217aed56921985f03f0cbbd852bba81d760cc749";

    // --- SAM mask decoder (MobileSAM repository) ---

    /// Download URL of the decoder model.
    pub const DECODER_URL: &str =
        "https://huggingface.co/Acly/MobileSAM/resolve/main/sam_mask_decoder_single.onnx?download=true";
    /// Filename of the decoder model.
    pub const DECODER_FILENAME: &str = "sam_mask_decoder_single.onnx";
    /// SHA-256 hex string associated with the decoder model file.
    pub const DECODER_SHA256: &str =
        "93915fc7c993ab9d59ab8c9ccd3bce37f7509c81ab4150a74abd4d2abbd8570d";

    // --- LaMa inpainting model ---

    /// Download URL of the LaMa model; the URL pins a specific repository revision
    /// instead of `main`.
    pub const LAMA_URL: &str =
        "https://huggingface.co/qualcomm/LaMa-Dilated/resolve/ab898502c9bd764a50eb2719a309694b43eae658/LaMa-Dilated.onnx?download=true";
    /// Filename of the LaMa model.
    pub const LAMA_FILENAME: &str = "LaMa-Dilated.onnx";
    /// SHA-256 hex string associated with the LaMa model file.
    pub const LAMA_SHA256: &str =
        "6f9e1d401eb67a63fb1be6c0cf3283d800bf4c20656028f96b044fedc382d762";
}

/// Download URLs, on-disk filenames and SHA-256 hex strings for the models selected
/// on every target other than Android.
///
/// The constant names mirror the Android `sam_cfg` module so the rest of the file can
/// refer to them without caring which target is being built.
#[cfg(not(target_os = "android"))]
mod sam_cfg {
    // --- SAM ViT-B image encoder ---

    /// Download URL of the encoder model.
    pub const ENCODER_URL: &str =
        "https://huggingface.co/CyberTimon/RapidRAW-Models/resolve/main/sam_vit_b_01ec64_encoder.onnx?download=true";
    /// Filename of the encoder model.
    pub const ENCODER_FILENAME: &str = "sam_vit_b_01ec64_encoder.onnx";
    /// SHA-256 hex string associated with the encoder model file.
    pub const ENCODER_SHA256: &str =
        "16ab73d9c824886f0de2938c19df22fb9ec3deebfd0de58e65177e479213d7d1";

    // --- SAM ViT-B mask decoder ---

    /// Download URL of the decoder model.
    pub const DECODER_URL: &str =
        "https://huggingface.co/CyberTimon/RapidRAW-Models/resolve/main/sam_vit_b_01ec64_decoder.onnx?download=true";
    /// Filename of the decoder model.
    pub const DECODER_FILENAME: &str = "sam_vit_b_01ec64_decoder.onnx";
    /// SHA-256 hex string associated with the decoder model file.
    pub const DECODER_SHA256: &str =
        "85d0d672cf5b7fe763edcde429e5533e62f674af4b15c7d688b7673b0ef00bf7";

    // --- LaMa inpainting model ---

    /// Download URL of the LaMa model.
    pub const LAMA_URL: &str =
        "https://huggingface.co/CyberTimon/RapidRAW-Models/resolve/main/lama_fp16.onnx?download=true";
    /// Filename of the LaMa model.
    pub const LAMA_FILENAME: &str = "lama_fp16.onnx";
    /// SHA-256 hex string associated with the LaMa model file.
    pub const LAMA_SHA256: &str =
        "2d6be6277c400d6f1b91819737f7c3da935e5c63d1b521d393be1196a2bfa82c";
}

use sam_cfg::*;
const SAM_INPUT_SIZE: u32 = 1024;
const ENCODER_SHA256: &str = "16ab73d9c824886f0de2938c19df22fb9ec3deebfd0de58e65177e479213d7d1";
const DECODER_SHA256: &str = "85d0d672cf5b7fe763edcde429e5533e62f674af4b15c7d688b7673b0ef00bf7";

const U2NETP_URL: &str =
"https://huggingface.co/CyberTimon/RapidRAW-Models/resolve/main/u2net.onnx?download=true";
Expand All @@ -48,11 +74,6 @@
const DENOISE_FILENAME: &str = "nind_denoise_utnet_684.onnx";
const DENOISE_SHA256: &str = "ee3586279d514df557ff3f7dec6df37fafc51ba5d3a3435b2cc9ac2d9017e7fe";

const LAMA_URL: &str =
"https://huggingface.co/CyberTimon/RapidRAW-Models/resolve/main/lama_fp16.onnx?download=true";
const LAMA_FILENAME: &str = "lama_fp16.onnx";
const LAMA_SHA256: &str = "2d6be6277c400d6f1b91819737f7c3da935e5c63d1b521d393be1196a2bfa82c";

const DEPTH_URL: &str = "https://huggingface.co/CyberTimon/RapidRAW-Models/resolve/main/depth_anything_v2_vits.onnx?download=true";
const DEPTH_FILENAME: &str = "depth_anything_v2_vits.onnx";
const DEPTH_INPUT_SIZE: u32 = 518;
Expand Down Expand Up @@ -94,6 +115,31 @@
pub depth_map: Option<CachedDepthMap>,
}

/// Resets every cached entry of the AI state to `None`.
///
/// The fields cleared are `models`, `denoise_model`, `clip_models`, `lama_model`,
/// `embeddings` and `depth_map`; whatever each one previously held is dropped by the
/// assignment. When `ai_state_lock` is itself `None` the call is a no-op.
pub fn clear_all_models(ai_state_lock: &mut Option<AiState>) {
    // Nothing to release if the state was never created.
    let Some(ai_state) = ai_state_lock.as_mut() else {
        return;
    };

    ai_state.models = None;
    ai_state.denoise_model = None;
    ai_state.clip_models = None;
    ai_state.lama_model = None;
    ai_state.embeddings = None;
    ai_state.depth_map = None;
}

/// Applies the platform-specific execution provider to a session builder.
///
/// On Android the builder is configured with the default NNAPI execution provider; on
/// every other target it is handed back untouched.
///
/// # Errors
///
/// Returns an error if `with_execution_providers` fails (Android builds only).
fn add_platform_optimization(builder: SessionBuilder) -> Result<SessionBuilder> {
    // Shadow the builder only on Android so both targets share a single return path
    // instead of two mutually exclusive cfg blocks.
    #[cfg(target_os = "android")]
    let builder = builder.with_execution_providers([NNAPIExecutionProvider::default().build()])?;

    Ok(builder)
}

fn edt_1d(f: &mut [f32], v: &mut [usize], z: &mut [f32], d: &mut [f32]) {
let n = f.len();
if n == 0 {
Expand Down Expand Up @@ -235,6 +281,13 @@
}

let _guard = ai_init_lock.lock().await;
#[cfg(target_os = "android")]

Check warning on line 284 in src-tauri/src/ai_processing.rs

View workflow job for this annotation

GitHub Actions / cargo fmt

Diff in /home/runner/work/RapidRAW/RapidRAW/src-tauri/src/ai_processing.rs

Check warning on line 284 in src-tauri/src/ai_processing.rs

View workflow job for this annotation

GitHub Actions / cargo fmt

Diff in /home/runner/work/RapidRAW/RapidRAW/src-tauri/src/ai_processing.rs
{
let mut ai_state_lock = ai_state_mutex.lock().unwrap();
if ai_state_lock.as_ref().and_then(|s| s.models.as_ref()).is_none() {
clear_all_models(&mut ai_state_lock);
}
}

if let Some(models) = ai_state_mutex
.lock()
Expand Down Expand Up @@ -298,14 +351,14 @@
let encoder_path = models_dir.join(ENCODER_FILENAME);
let decoder_path = models_dir.join(DECODER_FILENAME);
let u2netp_path = models_dir.join(U2NETP_FILENAME);
let sky_seg_path = models_dir.join(SKYSEG_FILENAME);

Check warning on line 354 in src-tauri/src/ai_processing.rs

View workflow job for this annotation

GitHub Actions / cargo fmt

Diff in /home/runner/work/RapidRAW/RapidRAW/src-tauri/src/ai_processing.rs

Check warning on line 354 in src-tauri/src/ai_processing.rs

View workflow job for this annotation

GitHub Actions / cargo fmt

Diff in /home/runner/work/RapidRAW/RapidRAW/src-tauri/src/ai_processing.rs
let depth_path = models_dir.join(DEPTH_FILENAME);

let sam_encoder = Session::builder()?.commit_from_file(encoder_path)?;
let sam_decoder = Session::builder()?.commit_from_file(decoder_path)?;
let u2netp = Session::builder()?.commit_from_file(u2netp_path)?;
let sky_seg = Session::builder()?.commit_from_file(sky_seg_path)?;
let depth_anything = Session::builder()?.commit_from_file(depth_path)?;
let sam_encoder = add_platform_optimization(Session::builder()?)?.commit_from_file(encoder_path)?;
let sam_decoder = add_platform_optimization(Session::builder()?)?.commit_from_file(decoder_path)?;
let u2netp = add_platform_optimization(Session::builder()?)?.commit_from_file(u2netp_path)?;
let sky_seg = add_platform_optimization(Session::builder()?)?.commit_from_file(sky_seg_path)?;
let depth_anything = add_platform_optimization(Session::builder()?)?.commit_from_file(depth_path)?;

crate::register_exit_handler();

Expand Down Expand Up @@ -349,6 +402,13 @@
}

let _guard = ai_init_lock.lock().await;
#[cfg(target_os = "android")]

Check warning on line 405 in src-tauri/src/ai_processing.rs

View workflow job for this annotation

GitHub Actions / cargo fmt

Diff in /home/runner/work/RapidRAW/RapidRAW/src-tauri/src/ai_processing.rs

Check warning on line 405 in src-tauri/src/ai_processing.rs

View workflow job for this annotation

GitHub Actions / cargo fmt

Diff in /home/runner/work/RapidRAW/RapidRAW/src-tauri/src/ai_processing.rs
{
let mut ai_state_lock = ai_state_mutex.lock().unwrap();
if ai_state_lock.as_ref().and_then(|s| s.denoise_model.as_ref()).is_none() {
clear_all_models(&mut ai_state_lock);
}
}

if let Some(denoise_model) = ai_state_mutex
.lock()
Expand All @@ -372,7 +432,7 @@

let _ = ort::init().with_name("AI-Denoise").commit();
let model_path = models_dir.join(DENOISE_FILENAME);
let session = Session::builder()?.commit_from_file(model_path)?;
let session = add_platform_optimization(Session::builder()?)?.commit_from_file(model_path)?;
let denoise_model = Arc::new(Mutex::new(session));

crate::register_exit_handler();
Expand Down Expand Up @@ -480,6 +540,13 @@
}

let _guard = ai_init_lock.lock().await;
#[cfg(target_os = "android")]

Check warning on line 543 in src-tauri/src/ai_processing.rs

View workflow job for this annotation

GitHub Actions / cargo fmt

Diff in /home/runner/work/RapidRAW/RapidRAW/src-tauri/src/ai_processing.rs

Check warning on line 543 in src-tauri/src/ai_processing.rs

View workflow job for this annotation

GitHub Actions / cargo fmt

Diff in /home/runner/work/RapidRAW/RapidRAW/src-tauri/src/ai_processing.rs
{
let mut ai_state_lock = ai_state_mutex.lock().unwrap();
if ai_state_lock.as_ref().and_then(|s| s.lama_model.as_ref()).is_none() {
clear_all_models(&mut ai_state_lock);
}
}

if let Some(lama_model) = ai_state_mutex
.lock()
Expand Down Expand Up @@ -815,6 +882,9 @@
let cropped_img = imageops::crop_imm(&rgba, x0, y0, crop_w, crop_h).to_image();
let cropped_mask = imageops::crop_imm(mask, x0, y0, crop_w, crop_h).to_image();

#[cfg(target_os = "android")]
let max_dim_limit: u32 = 512;
#[cfg(not(target_os = "android"))]
let max_dim_limit: u32 = 768;
let needs_downscale = crop_w > max_dim_limit || crop_h > max_dim_limit;

Expand Down Expand Up @@ -871,9 +941,14 @@
let mut result_inf = RgbaImage::new(fw, fh);
for y in 0..fh {
for x in 0..fw {
let r = output_tensor[[0, 0, y as usize, x as usize]].clamp(0.0, 255.0) as u8;
let g = output_tensor[[0, 1, y as usize, x as usize]].clamp(0.0, 255.0) as u8;
let b = output_tensor[[0, 2, y as usize, x as usize]].clamp(0.0, 255.0) as u8;
#[cfg(target_os = "android")]
let multiplier = 255.0;
#[cfg(not(target_os = "android"))]
let multiplier = 1.0;

let r = (output_tensor[[0, 0, y as usize, x as usize]] * multiplier).clamp(0.0, 255.0) as u8;
let g = (output_tensor[[0, 1, y as usize, x as usize]] * multiplier).clamp(0.0, 255.0) as u8;
let b = (output_tensor[[0, 2, y as usize, x as usize]] * multiplier).clamp(0.0, 255.0) as u8;
result_inf.put_pixel(x, y, Rgba([r, g, b, 255]));
}
}
Expand Down Expand Up @@ -924,32 +999,70 @@
let (actual_width, actual_height) = rgb_image.dimensions();
let raw_pixels = rgb_image.as_raw();

let mut input_tensor: Array<u8, _> =
Array::zeros((1, 3, SAM_INPUT_SIZE as usize, SAM_INPUT_SIZE as usize));
#[cfg(target_os = "android")]
{
let mut input_tensor = Array::<f32, ndarray::Ix3>::zeros((
SAM_INPUT_SIZE as usize,
SAM_INPUT_SIZE as usize,
3
));

let w_usize = actual_width as usize;
for y in 0..(actual_height as usize) {
for x in 0..w_usize {
let idx = (y * w_usize + x) * 3;
input_tensor[[0, 0, y, x]] = raw_pixels[idx];
input_tensor[[0, 1, y, x]] = raw_pixels[idx + 1];
input_tensor[[0, 2, y, x]] = raw_pixels[idx + 2];
let w_usize = actual_width as usize;
for y in 0..actual_height {
let y_idx = y as usize;
let y_offset = y_idx * w_usize;
for x in 0..actual_width {
let x_idx = x as usize;
let idx = (y_offset + x_idx) * 3;

input_tensor[[y_idx, x_idx, 0]] = raw_pixels[idx] as f32;
input_tensor[[y_idx, x_idx, 1]] = raw_pixels[idx + 1] as f32;
input_tensor[[y_idx, x_idx, 2]] = raw_pixels[idx + 2] as f32;
}
}

let input_tensor_dyn = input_tensor.into_dyn().as_standard_layout().into_owned();
let input_tensor_ort = Tensor::from_array(input_tensor_dyn)?;
let mut session = encoder.lock().unwrap();

let outputs = session.run(ort::inputs!["input_image" => input_tensor_ort])?;
let embeddings = outputs[0].try_extract_array::<f32>()?.to_owned();

Ok(ImageEmbeddings {
path_hash: "".to_string(),
embeddings: embeddings.into_dyn(),
original_size: (orig_width, orig_height),
})
}
#[cfg(not(target_os = "android"))]
{
let mut input_tensor: Array<u8, _> =
Array::zeros((1, 3, SAM_INPUT_SIZE as usize, SAM_INPUT_SIZE as usize));

let w_usize = actual_width as usize;
for y in 0..(actual_height as usize) {
for x in 0..w_usize {
let idx = (y * w_usize + x) * 3;
input_tensor[[0, 0, y, x]] = raw_pixels[idx];
input_tensor[[0, 1, y, x]] = raw_pixels[idx + 1];
input_tensor[[0, 2, y, x]] = raw_pixels[idx + 2];
}
}

let input_tensor_dyn = input_tensor.into_dyn();
let input_values = input_tensor_dyn.as_standard_layout();
let input_tensor_ort = Tensor::from_array(input_values.into_owned())?;
let mut session = encoder.lock().unwrap();
let outputs = session.run(ort::inputs![input_tensor_ort])?;

let embeddings = outputs[0].try_extract_array::<f32>()?.to_owned();

Ok(ImageEmbeddings {
path_hash: "".to_string(),
embeddings: embeddings.into_dyn(),
original_size: (orig_width, orig_height),
})
let input_tensor_dyn = input_tensor.into_dyn();
let input_values = input_tensor_dyn.as_standard_layout();
let input_tensor_ort = Tensor::from_array(input_values.into_owned())?;
let mut session = encoder.lock().unwrap();
let outputs = session.run(ort::inputs![input_tensor_ort])?;

let embeddings = outputs[0].try_extract_array::<f32>()?.to_owned();

Ok(ImageEmbeddings {
path_hash: "".to_string(),
embeddings: embeddings.into_dyn(),
original_size: (orig_width, orig_height),
})
}
}

pub fn run_sam_decoder(
Expand Down
Loading