From 8d6f2f49f962e73302180f15e9fd4b1333da105f Mon Sep 17 00:00:00 2001 From: Abhirupa Ghosh <100681585+AbhirupaGhosh@users.noreply.github.com> Date: Fri, 13 Mar 2026 13:52:35 -0600 Subject: [PATCH 1/4] Add computeFeatureImprovement function for feature analysis This function computes feature improvement by reading and processing feature data from parquet files, applying various transformations and calculations to derive insights on feature importance and contributions. --- R/feature_rescoring.R | 73 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 R/feature_rescoring.R diff --git a/R/feature_rescoring.R b/R/feature_rescoring.R new file mode 100644 index 0000000..23d27a7 --- /dev/null +++ b/R/feature_rescoring.R @@ -0,0 +1,73 @@ +computeFeatureImprovement <- function( + all_feature_parquet, + feature_cluster_parquet +) { + stopifnot(file.exists(all_feature_parquet)) +feature_cluster <- arrow::read_parquet(normalizePath(feature_cluster_parquet)) + + + features_rescored <- arrow::read_parquet(normalizePath(all_feature_parquet)) |> + dplyr::select( + output_prefix, + drug_label, drug_or_class, shuffled, pca, + feature_type, feature_subtype, Variable, + Importance, Sign + ) |> + dplyr::filter(!pca) |> + dplyr::mutate( + Variable = dplyr::case_when( + feature_type == "domains" ~ sub("_.+$", "", Variable), + feature_type == "proteins" ~ sub("fig.", "fig|", Variable, fixed = TRUE), + TRUE ~ Variable + ) + ) |> + dplyr::group_by(output_prefix) |> + dplyr::mutate( + rank = dplyr::dense_rank(dplyr::desc(Importance)), + denom = sum(Importance, na.rm = TRUE), + contribution = dplyr::if_else(denom > 0, Importance / denom, 0), + + # Safer rescaling within each output_prefix + min_imp = min(Importance, na.rm = TRUE), + max_imp = max(Importance, na.rm = TRUE), + range_imp = max_imp - min_imp, + rescaled = dplyr::if_else(range_imp > 0, (Importance - min_imp) / range_imp, 0) + ) |> + dplyr::ungroup() |> + dplyr::group_by(drug_label, drug_or_class, feature_type, feature_subtype, shuffled, Variable) |> + dplyr::mutate(median_datatype = stats::median(rank, na.rm = TRUE)) |> + dplyr::ungroup() |> + dplyr::group_by(drug_label, drug_or_class, feature_type, shuffled, Variable) |> + dplyr::mutate(median_scale = stats::median(median_datatype, na.rm = TRUE)) |> + dplyr::ungroup() |> + # NOTE: correct join syntax (no quotes in join_by) + dplyr::left_join(feature_cluster, by = dplyr::join_by(Variable == feature)) |> + dplyr::group_by(drug_label, drug_or_class, shuffled, cluster) |> + dplyr::mutate( + median_drug_or_class = stats::median(rank, na.rm = TRUE), + count_scales_for_cluster = dplyr::n_distinct(feature_type), + feature_types_csv = paste(sort(unique(feature_type)), collapse = ","), + lowest_contri = min(contribution, na.rm = TRUE), + highest_contri = max(contribution, na.rm = TRUE) + ) |> + dplyr::ungroup() |> + dplyr::group_by(drug_label, drug_or_class, cluster) |> + dplyr::mutate( + shuffled_rank = dplyr::if_else(shuffled, median_drug_or_class, NA_real_), + non_shuffled_rank = dplyr::if_else(!shuffled, median_drug_or_class, NA_real_), + + # If non-shuffled is missing -> NA (no evidence). If shuffled missing -> +Inf improvement. + improvement = dplyr::case_when( + !is.na(non_shuffled_rank) ~ tidyr::replace_na(shuffled_rank, Inf) - non_shuffled_rank, + TRUE ~ NA_real_ + ), + + # "Good" if non-shuffled exists AND (non-shuffled < shuffled OR shuffled is missing) + good_feature = !is.na(non_shuffled_rank) & improvement > 0 + ) |> + dplyr::ungroup() + + return(features_rescored) + + } + From 729c8370c7fe0d8c07d7ab348ce0b4dd70be1d91 Mon Sep 17 00:00:00 2001 From: AbhirupaGhosh Date: Fri, 13 Mar 2026 19:57:53 +0000 Subject: [PATCH 2/4] Style code (GHA) --- R/core_ml.R | 217 ++++++++++++------- R/feature_rescoring.R | 34 ++- R/generate_matrices_ml.R | 197 ++++++++++------- R/globals.R | 2 - R/plot_ml.R | 1 - R/prep_ml.R | 6 +- R/run_ML.R | 450 +++++++++++++++++++++------------------ R/run_ml_pipeline.R | 41 ++-- vignettes/intro.Rmd | 30 +-- 9 files changed, 559 insertions(+), 419 deletions(-) diff --git a/R/core_ml.R b/R/core_ml.R index db874db..1fb0c1b 100644 --- a/R/core_ml.R +++ b/R/core_ml.R @@ -73,7 +73,8 @@ NULL #' @return An `rsplit` object #' @export splitMLInputTibble <- function(ml_input_tibble, split = c(0.6, 0.2), seed = 5280) { - .checkArgTibble(ml_input_tibble, ml = TRUE); .checkArgSplit(split) + .checkArgTibble(ml_input_tibble, ml = TRUE) + .checkArgSplit(split) .checkArgSeed(seed) set.seed(seed) @@ -85,7 +86,7 @@ splitMLInputTibble <- function(ml_input_tibble, split = c(0.6, 0.2), seed = 5280 # If in CV mode: # Still retain a stratified testing holdout purely for final reporting metrics; # CV is only performed on the training portion. - prop_train_for_holdout <- 0.8 # 80 percent train, 20 percent reserved test + prop_train_for_holdout <- 0.8 # 80 percent train, 20 percent reserved test data_split <- rsample::initial_split( ml_input_tibble, prop = prop_train_for_holdout, @@ -115,7 +116,8 @@ splitMLInputTibble <- function(ml_input_tibble, split = c(0.6, 0.2), seed = 5280 #' @return A `recipe` object #' @export buildRecipe <- function(train_data, use_pca = FALSE, pca_threshold = 0.95) { - .checkArgTibble(train_data, ml = TRUE); .checkArgUsePCA(use_pca) + .checkArgTibble(train_data, ml = TRUE) + .checkArgUsePCA(use_pca) .checkArgPCAThreshold(pca_threshold) target_var <- .getTargetVarName(train_data) |> as.character() @@ -124,8 +126,10 @@ buildRecipe <- function(train_data, use_pca = FALSE, pca_threshold = 0.95) { nm <- names(train_data) id_cols <- setdiff(nm[grepl("^genome", nm)], target_var) - rec <- recipes::recipe(formula = stats::reformulate(".", response = target_var), - data = train_data) + rec <- recipes::recipe( + formula = stats::reformulate(".", response = target_var), + data = train_data + ) # Only update roles if we actually have ID columns to mark as metadata if (length(id_cols) > 0) { @@ -146,7 +150,6 @@ buildRecipe <- function(train_data, use_pca = FALSE, pca_threshold = 0.95) { } - #' buildLRModel() #' #' Builds a logistic regression model. @@ -158,13 +161,17 @@ buildRecipe <- function(train_data, use_pca = FALSE, pca_threshold = 0.95) { buildLRModel <- function(multi_class = FALSE) { .checkArgMultiClass(multi_class) - if(!multi_class) { - lr_mod <- parsnip::logistic_reg(penalty = hardhat::tune(), - mixture = hardhat::tune()) |> + if (!multi_class) { + lr_mod <- parsnip::logistic_reg( + penalty = hardhat::tune(), + mixture = hardhat::tune() + ) |> parsnip::set_engine(engine = "glmnet") - } else if(multi_class) { - lr_mod <- parsnip::multinom_reg(penalty = hardhat::tune(), - mixture = hardhat::tune()) |> + } else if (multi_class) { + lr_mod <- parsnip::multinom_reg( + penalty = hardhat::tune(), + mixture = hardhat::tune() + ) |> parsnip::set_engine(engine = "glmnet") } @@ -181,9 +188,11 @@ buildLRModel <- function(multi_class = FALSE) { #' @return A `workflow` object #' @export buildWflow <- function(parsnip_mod, recipe) { - .checkArgParsnipMod(parsnip_mod); .checkArgRecipe(recipe) + .checkArgParsnipMod(parsnip_mod) + .checkArgRecipe(recipe) - wflow <- workflows::workflow() |> workflows::add_model(parsnip_mod) |> + wflow <- workflows::workflow() |> + workflows::add_model(parsnip_mod) |> workflows::add_recipe(recipe) return(wflow) @@ -203,21 +212,21 @@ buildWflow <- function(parsnip_mod, recipe) { #' @return A logistic regression tuning grid as a tibble #' @export buildTuningGrid <- function( - model = "LR", - penalty_vec = 10^seq(-4, -1, length.out = 10), - mix_vec = 0:5 / 5 + model = "LR", + penalty_vec = 10^seq(-4, -1, length.out = 10), + mix_vec = 0:5 / 5 ) { .checkArgModel(model) - + if (model == "LR") { .checkArgPenaltyVec(penalty_vec) .checkArgMixVec(mix_vec) - + penalty <- rep(penalty_vec, each = length(mix_vec)) mixture <- rep(mix_vec, length(penalty_vec)) grid <- tibble::tibble(penalty, mixture) } - + return(grid) } @@ -237,13 +246,14 @@ buildTuningGrid <- function( #' @export tuneGrid <- function(wflow, data_split, grid = buildTuningGrid(model = "LR"), n_fold = 5) { - .checkArgTibble(grid); .checkArgWflow(wflow) + .checkArgTibble(grid) + .checkArgWflow(wflow) .checkArgDataSplit(data_split) split_class <- class(data_split)[1] # Always do CV on the training portion of the split - train_df <- rsample::training(data_split) + train_df <- rsample::training(data_split) target_var <- .getTargetVarName(train_df) if (identical(split_class, "initial_split")) { @@ -259,9 +269,9 @@ tuneGrid <- function(wflow, data_split, grid = buildTuningGrid(model = "LR"), tune_res <- tune::tune_grid( wflow, resamples = resamples, - grid = grid, - control = tune::control_grid(save_pred = TRUE), - metrics = yardstick::metric_set( + grid = grid, + control = tune::control_grid(save_pred = TRUE), + metrics = yardstick::metric_set( yardstick::f_meas, yardstick::pr_auc, yardstick::spec, @@ -286,7 +296,8 @@ tuneGrid <- function(wflow, data_split, grid = buildTuningGrid(model = "LR"), #' @return Best model workflow #' @export selectBestModel <- function(tune_res, wflow, select_best_metric = "mcc") { - .checkArgTuneRes(tune_res); .checkArgWflow(wflow) + .checkArgTuneRes(tune_res) + .checkArgWflow(wflow) .checkArgSelectBestMetric(select_best_metric) best_mod <- tune::select_best(tune_res, metric = select_best_metric) @@ -306,7 +317,8 @@ selectBestModel <- function(tune_res, wflow, select_best_metric = "mcc") { #' @return Best model fit #' @export fitBestModel <- function(final_mod, train_data) { - .checkArgWflow(final_mod); .checkArgTibble(train_data, ml = TRUE) + .checkArgWflow(final_mod) + .checkArgTibble(train_data, ml = TRUE) fit <- final_mod |> parsnip::fit(data = train_data) @@ -324,8 +336,7 @@ fitBestModel <- function(final_mod, train_data) { model <- class(fit$fit$actions$model$spec)[1] - if(model %in% c("logistic_reg", "multinom_reg")) { - + if (model %in% c("logistic_reg", "multinom_reg")) { penalty <- fit$fit$fit$spec$args$penalty mixture <- tryCatch( @@ -334,7 +345,6 @@ fitBestModel <- function(final_mod, train_data) { ) tibble::tibble(penalty = penalty, mixture = mixture) - } else { stop("The `fit` object provided must correspond to 'logistic_reg' or 'multinom_reg'.") } @@ -353,7 +363,8 @@ fitBestModel <- function(final_mod, train_data) { #' labels #' @export predictML <- function(fit, test_data) { - .checkArgWflow(fit); .checkArgTibble(test_data, ml = TRUE) + .checkArgWflow(fit) + .checkArgTibble(test_data, ml = TRUE) test_data_plus_predictions <- parsnip::augment(fit, test_data) @@ -396,7 +407,8 @@ getConfusionMatrix <- function(test_data_plus_predictions) { mcc <- test_data_plus_predictions |> yardstick::mcc(truth = !!target_var, estimate = .pred_class) |> - dplyr::select(.estimate) |> as.numeric() + dplyr::select(.estimate) |> + as.numeric() nmcc <- (mcc + 1) / 2 @@ -413,15 +425,21 @@ getConfusionMatrix <- function(test_data_plus_predictions) { .calculateF1 <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) - if(!("genome_drug.resistant_phenotype" %in% + if (!("genome_drug.resistant_phenotype" %in% colnames(test_data_plus_predictions))) { - stop(paste("`test_data_plus_predictions` does not have a column for", - "`genome_drug.resistant_phenotype`.")) + stop(paste( + "`test_data_plus_predictions` does not have a column for", + "`genome_drug.resistant_phenotype`." + )) } f1 <- test_data_plus_predictions |> - yardstick::f_meas(truth = genome_drug.resistant_phenotype, - estimate = .pred_class) |> dplyr::select(.estimate) |> as.numeric() |> + yardstick::f_meas( + truth = genome_drug.resistant_phenotype, + estimate = .pred_class + ) |> + dplyr::select(.estimate) |> + as.numeric() |> round(2) return(f1) @@ -437,16 +455,21 @@ getConfusionMatrix <- function(test_data_plus_predictions) { .calculateAUPRC <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) - if(!("genome_drug.resistant_phenotype" %in% + if (!("genome_drug.resistant_phenotype" %in% colnames(test_data_plus_predictions))) { - stop(paste("`test_data_plus_predictions` does not have a column for", - "`genome_drug.resistant_phenotype`.")) + stop(paste( + "`test_data_plus_predictions` does not have a column for", + "`genome_drug.resistant_phenotype`." + )) } auprc <- test_data_plus_predictions |> yardstick::pr_auc( - truth = genome_drug.resistant_phenotype, .pred_Resistant) |> - dplyr::select(.estimate) |> as.numeric() |> round(2) + truth = genome_drug.resistant_phenotype, .pred_Resistant + ) |> + dplyr::select(.estimate) |> + as.numeric() |> + round(2) return(auprc) } @@ -461,26 +484,33 @@ getConfusionMatrix <- function(test_data_plus_predictions) { .calculateLog2APOP <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) - if(!("genome_drug.resistant_phenotype" %in% + if (!("genome_drug.resistant_phenotype" %in% colnames(test_data_plus_predictions))) { - stop(paste("`test_data_plus_predictions` does not have a column for", - "`genome_drug.resistant_phenotype`.")) + stop(paste( + "`test_data_plus_predictions` does not have a column for", + "`genome_drug.resistant_phenotype`." + )) } auprc <- .calculateAUPRC(test_data_plus_predictions) prior <- sum( - test_data_plus_predictions$genome_drug.resistant_phenotype == "Resistant") / + test_data_plus_predictions$genome_drug.resistant_phenotype == "Resistant" + ) / nrow(test_data_plus_predictions) - if(prior > 0.3 && prior < 0.7) { - warning(paste("Classes are roughly balanced.", - "Calculation of log2(AUPRC/prior) may be inappropriate.")) - } else if(prior >= 0.7) { - warning(paste("Classes are imbalanced toward the resistant phenotype.", - "Calculation of log2(AUPRC/prior) may be inappropriate.")) + if (prior > 0.3 && prior < 0.7) { + warning(paste( + "Classes are roughly balanced.", + "Calculation of log2(AUPRC/prior) may be inappropriate." + )) + } else if (prior >= 0.7) { + warning(paste( + "Classes are imbalanced toward the resistant phenotype.", + "Calculation of log2(AUPRC/prior) may be inappropriate." + )) } - log2_apop <- log2(auprc/prior) |> round(2) + log2_apop <- log2(auprc / prior) |> round(2) return(log2_apop) } @@ -495,16 +525,21 @@ getConfusionMatrix <- function(test_data_plus_predictions) { .calculateBalAcc <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) - if(!("genome_drug.resistant_phenotype" %in% + if (!("genome_drug.resistant_phenotype" %in% colnames(test_data_plus_predictions))) { - stop(paste("`test_data_plus_predictions` does not have a column for", - "`genome_drug.resistant_phenotype`.")) + stop(paste( + "`test_data_plus_predictions` does not have a column for", + "`genome_drug.resistant_phenotype`." + )) } bal_acc <- test_data_plus_predictions |> yardstick::bal_accuracy( - truth = genome_drug.resistant_phenotype, estimate = .pred_class) |> - dplyr::select(.estimate) |> as.numeric() |> round(2) + truth = genome_drug.resistant_phenotype, estimate = .pred_class + ) |> + dplyr::select(.estimate) |> + as.numeric() |> + round(2) return(bal_acc) } @@ -519,15 +554,21 @@ getConfusionMatrix <- function(test_data_plus_predictions) { .calculateSensitivity <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) - if(!("genome_drug.resistant_phenotype" %in% + if (!("genome_drug.resistant_phenotype" %in% colnames(test_data_plus_predictions))) { - stop(paste("`test_data_plus_predictions` does not have a column for", - "`genome_drug.resistant_phenotype`.")) + stop(paste( + "`test_data_plus_predictions` does not have a column for", + "`genome_drug.resistant_phenotype`." + )) } sens <- test_data_plus_predictions |> - yardstick::sens(truth = genome_drug.resistant_phenotype, - estimate = .pred_class) |> dplyr::select(.estimate) |> as.numeric() |> + yardstick::sens( + truth = genome_drug.resistant_phenotype, + estimate = .pred_class + ) |> + dplyr::select(.estimate) |> + as.numeric() |> round(2) return(sens) @@ -543,15 +584,21 @@ getConfusionMatrix <- function(test_data_plus_predictions) { .calculateSpecificity <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) - if(!("genome_drug.resistant_phenotype" %in% + if (!("genome_drug.resistant_phenotype" %in% colnames(test_data_plus_predictions))) { - stop(paste("`test_data_plus_predictions` does not have a column for", - "`genome_drug.resistant_phenotype`.")) + stop(paste( + "`test_data_plus_predictions` does not have a column for", + "`genome_drug.resistant_phenotype`." + )) } spec <- test_data_plus_predictions |> - yardstick::spec(truth = genome_drug.resistant_phenotype, - estimate = .pred_class) |> dplyr::select(.estimate) |> as.numeric() |> + yardstick::spec( + truth = genome_drug.resistant_phenotype, + estimate = .pred_class + ) |> + dplyr::select(.estimate) |> + as.numeric() |> round(2) return(spec) @@ -598,30 +645,36 @@ calculateEvalMets <- function(test_data_plus_predictions) { #' `Importance`, and a column for `Sign` (or, for multi-class, a tibble with #' per-class columns of importance scores for each `Variable`) #' @export -extractTopFeats <- function(fit, prop_vi_top_feats = c(0, 1), - n_top_feats = NA) { +extractTopFeats <- function( + fit, prop_vi_top_feats = c(0, 1), + n_top_feats = NA +) { .checkArgWflow(fit) - if(!is.na(n_top_feats)) {prop_vi_top_feats <- NA} + if (!is.na(n_top_feats)) { + prop_vi_top_feats <- NA + } # Arg checking for every permutation of `prop_vi_top_feats` and `n_top_feats` - if(is.na(n_top_feats) & any(!is.na(prop_vi_top_feats))) { + if (is.na(n_top_feats) & any(!is.na(prop_vi_top_feats))) { .checkArgPropVITopFeats(prop_vi_top_feats) - } else if(any(is.na(prop_vi_top_feats)) & !is.na(n_top_feats)) { + } else if (any(is.na(prop_vi_top_feats)) & !is.na(n_top_feats)) { .checkArgNTopFeats(n_top_feats) - } else if(any(!is.na(prop_vi_top_feats)) & !is.na(n_top_feats)) { + } else if (any(!is.na(prop_vi_top_feats)) & !is.na(n_top_feats)) { stop("Set either `n_top_feats` or `prop_vi_top_feats` to `NA` but not both.") - } else if(any(is.na(prop_vi_top_feats)) & is.na(n_top_feats)) { + } else if (any(is.na(prop_vi_top_feats)) & is.na(n_top_feats)) { stop("Please specify either `n_top_feats` or `prop_vi_top_feats`.") } - feats_arranged <- fit |> workflowsets::extract_fit_parsnip() |> vip::vi() |> + feats_arranged <- fit |> + workflowsets::extract_fit_parsnip() |> + vip::vi() |> dplyr::arrange(dplyr::desc(Importance)) - if(!is.na(n_top_feats)) { + if (!is.na(n_top_feats)) { top_feats_and_VIs <- feats_arranged |> dplyr::slice(1:n_top_feats) - } else if(any(!is.na(prop_vi_top_feats))) { + } else if (any(!is.na(prop_vi_top_feats))) { cum_vi_lower <- prop_vi_top_feats[1] * sum(feats_arranged$Importance) cum_vi_upper <- prop_vi_top_feats[2] * sum(feats_arranged$Importance) @@ -638,9 +691,11 @@ extractTopFeats <- function(fit, prop_vi_top_feats = c(0, 1), # Take a different approach if using multi-class (the previous code would give # a less meaningful result). - if(class(fit$fit$actions$model$spec)[1] == "multinom_reg") { - warning(paste("Extracting top features from a multi-class model.", - "The `prop_vi_top_feats` and `n_top_feats` arguments do not apply.")) + if (class(fit$fit$actions$model$spec)[1] == "multinom_reg") { + warning(paste( + "Extracting top features from a multi-class model.", + "The `prop_vi_top_feats` and `n_top_feats` arguments do not apply." + )) fit_penalty <- .getFitHps(fit)["penalty"] |> as.numeric() glmnet_fit <- parsnip::extract_fit_engine(fit) diff --git a/R/feature_rescoring.R b/R/feature_rescoring.R index 23d27a7..5c4fa0b 100644 --- a/R/feature_rescoring.R +++ b/R/feature_rescoring.R @@ -3,12 +3,12 @@ computeFeatureImprovement <- function( feature_cluster_parquet ) { stopifnot(file.exists(all_feature_parquet)) -feature_cluster <- arrow::read_parquet(normalizePath(feature_cluster_parquet)) - - - features_rescored <- arrow::read_parquet(normalizePath(all_feature_parquet)) |> + feature_cluster <- arrow::read_parquet(normalizePath(feature_cluster_parquet)) + + + features_rescored <- arrow::read_parquet(normalizePath(all_feature_parquet)) |> dplyr::select( - output_prefix, + output_prefix, drug_label, drug_or_class, shuffled, pca, feature_type, feature_subtype, Variable, Importance, Sign @@ -16,22 +16,22 @@ feature_cluster <- arrow::read_parquet(normalizePath(feature_cluster_parquet)) dplyr::filter(!pca) |> dplyr::mutate( Variable = dplyr::case_when( - feature_type == "domains" ~ sub("_.+$", "", Variable), + feature_type == "domains" ~ sub("_.+$", "", Variable), feature_type == "proteins" ~ sub("fig.", "fig|", Variable, fixed = TRUE), TRUE ~ Variable ) ) |> dplyr::group_by(output_prefix) |> dplyr::mutate( - rank = dplyr::dense_rank(dplyr::desc(Importance)), + rank = dplyr::dense_rank(dplyr::desc(Importance)), denom = sum(Importance, na.rm = TRUE), contribution = dplyr::if_else(denom > 0, Importance / denom, 0), # Safer rescaling within each output_prefix - min_imp = min(Importance, na.rm = TRUE), - max_imp = max(Importance, na.rm = TRUE), + min_imp = min(Importance, na.rm = TRUE), + max_imp = max(Importance, na.rm = TRUE), range_imp = max_imp - min_imp, - rescaled = dplyr::if_else(range_imp > 0, (Importance - min_imp) / range_imp, 0) + rescaled = dplyr::if_else(range_imp > 0, (Importance - min_imp) / range_imp, 0) ) |> dplyr::ungroup() |> dplyr::group_by(drug_label, drug_or_class, feature_type, feature_subtype, shuffled, Variable) |> @@ -53,21 +53,19 @@ feature_cluster <- arrow::read_parquet(normalizePath(feature_cluster_parquet)) dplyr::ungroup() |> dplyr::group_by(drug_label, drug_or_class, cluster) |> dplyr::mutate( - shuffled_rank = dplyr::if_else(shuffled, median_drug_or_class, NA_real_), + shuffled_rank = dplyr::if_else(shuffled, median_drug_or_class, NA_real_), non_shuffled_rank = dplyr::if_else(!shuffled, median_drug_or_class, NA_real_), # If non-shuffled is missing -> NA (no evidence). If shuffled missing -> +Inf improvement. improvement = dplyr::case_when( !is.na(non_shuffled_rank) ~ tidyr::replace_na(shuffled_rank, Inf) - non_shuffled_rank, - TRUE ~ NA_real_ + TRUE ~ NA_real_ ), # "Good" if non-shuffled exists AND (non-shuffled < shuffled OR shuffled is missing) good_feature = !is.na(non_shuffled_rank) & improvement > 0 ) |> - dplyr::ungroup() - - return(features_rescored) - - } - + dplyr::ungroup() + + return(features_rescored) +} diff --git a/R/generate_matrices_ml.R b/R/generate_matrices_ml.R index bb1dc68..b19beef 100644 --- a/R/generate_matrices_ml.R +++ b/R/generate_matrices_ml.R @@ -156,7 +156,6 @@ skipImbalancedMatrix <- function(genome_ids, split, stratify_by = NULL, verbosity = c("minimal", "debug")) { - verbosity <- match.arg(verbosity) log <- .make_logger(verbosity) @@ -197,8 +196,10 @@ skipImbalancedMatrix <- function(genome_ids, if (!dir.exists(matrix_path)) dir.create(matrix_path, recursive = TRUE) log("info", paste0("Matrix output directory: ", matrix_path)) - log("debug", paste0("Stratification: ", - ifelse(is.null(stratify_column), "None", stratify_column))) + log("debug", paste0( + "Stratification: ", + ifelse(is.null(stratify_column), "None", stratify_column) + )) # Feature and matrix types feature_types <- list( @@ -220,9 +221,11 @@ skipImbalancedMatrix <- function(genome_ids, # Safe DBI-quoting quote_condition <- function(group_cols, group_values, con) { - ids <- vapply(group_cols, - function(col) DBI::dbQuoteIdentifier(con, col), - character(1)) + ids <- vapply( + group_cols, + function(col) DBI::dbQuoteIdentifier(con, col), + character(1) + ) vals <- vapply( group_cols, function(col) { @@ -256,7 +259,6 @@ skipImbalancedMatrix <- function(genome_ids, log("debug", paste0("Found ", nrow(all_groups), " groups for type: ", group_type)) for (i in seq_len(nrow(all_groups))) { - # New connection for this group con <- DBI::dbConnect(duckdb::duckdb(), parquet_duckdb_path) @@ -268,13 +270,14 @@ skipImbalancedMatrix <- function(genome_ids, condition_string <- quote_condition(group_cols, group_values, con) # Strat filter - strat_filter <- if (!is.null(stratify_column)) + strat_filter <- if (!is.null(stratify_column)) { sprintf("AND \"%s\" IS NOT NULL AND \"%s\" != ''", stratify_column, stratify_column) - else "" + } else { + "" + } # Genome selection logic if (group_type %in% c("drug_class", "drug_class_year", "drug_class_country")) { - genome_ids <- DBI::dbGetQuery(con, sprintf(" WITH class_phenotypes AS ( SELECT \"genome_drug.genome_id\" AS genome_id, @@ -290,7 +293,6 @@ skipImbalancedMatrix <- function(genome_ids, FROM class_phenotypes WHERE any_resistant = 1 OR all_susceptible = 1 ", condition_string))[[1]] - } else { genome_ids <- DBI::dbGetQuery(con, sprintf(" SELECT DISTINCT \"genome_drug.genome_id\" @@ -310,19 +312,24 @@ skipImbalancedMatrix <- function(genome_ids, ", condition_string)) phenotype_summary <- paste( - apply(phenotype_counts_all, 1, - function(row) paste0(row["phenotype"], "=", row["count"])), + apply( + phenotype_counts_all, 1, + function(row) paste0(row["phenotype"], "=", row["count"]) + ), collapse = "; " ) # Apply skip logic if (skipImbalancedMatrix(genome_ids, phenotype_counts_all, n_fold, split, - verbosity = verbosity)) { - + verbosity = verbosity + )) { readr::write_lines( - sprintf("%s\tToo few samples for CV/split\t%d\t%s", - group_label, length(genome_ids), phenotype_summary), - log_path, append = TRUE + sprintf( + "%s\tToo few samples for CV/split\t%d\t%s", + group_label, length(genome_ids), phenotype_summary + ), + log_path, + append = TRUE ) DBI::dbDisconnect(con, shutdown = FALSE) @@ -331,9 +338,12 @@ skipImbalancedMatrix <- function(genome_ids, if (length(genome_ids) < 40) { readr::write_lines( - sprintf("%s\tToo few observations\t%d\t%s", - group_label, length(genome_ids), phenotype_summary), - log_path, append = TRUE + sprintf( + "%s\tToo few observations\t%d\t%s", + group_label, length(genome_ids), phenotype_summary + ), + log_path, + append = TRUE ) DBI::dbDisconnect(con, shutdown = FALSE) @@ -351,9 +361,12 @@ skipImbalancedMatrix <- function(genome_ids, if (nrow(phen2) < 2) { readr::write_lines( - sprintf("%s\tOnly one phenotype class\t%d\t%s", - group_label, length(genome_ids), phenotype_summary), - log_path, append = TRUE + sprintf( + "%s\tOnly one phenotype class\t%d\t%s", + group_label, length(genome_ids), phenotype_summary + ), + log_path, + append = TRUE ) DBI::dbDisconnect(con, shutdown = FALSE) @@ -363,13 +376,14 @@ skipImbalancedMatrix <- function(genome_ids, # Create selected_genomes DBI::dbExecute(con, "CREATE OR REPLACE TEMP TABLE selected_genomes (genome_id VARCHAR)") DBI::dbWriteTable(con, "selected_genomes", - data.frame(genome_id = genome_ids), append = TRUE) + data.frame(genome_id = genome_ids), + append = TRUE + ) # Feature and matrix generation steps for (ftype in names(feature_types)) { - fview <- feature_types[[ftype]]$view - fid <- feature_types[[ftype]]$id_col + fid <- feature_types[[ftype]]$id_col # binary view DBI::dbExecute(con, sprintf(" @@ -389,13 +403,14 @@ skipImbalancedMatrix <- function(genome_ids, } for (mtype in names(matrix_types)) { - binary_only <- matrix_types[[mtype]]$binary_only if (ftype == "struct" && !binary_only) next - mview <- sprintf("%s_%s", ftype, - ifelse(grepl("binary", mtype), "binary", "counts")) - value_col <- matrix_types[[mtype]]$value_col + mview <- sprintf( + "%s_%s", ftype, + ifelse(grepl("binary", mtype), "binary", "counts") + ) + value_col <- matrix_types[[mtype]]$value_col filter_clause <- matrix_types[[mtype]]$filter # select features with non-zero variance @@ -409,29 +424,38 @@ skipImbalancedMatrix <- function(genome_ids, keep_features <- DBI::dbGetQuery(con, keep_query)[["feature_id"]] if (length(keep_features) == 0) { - log("info", paste0("All features filtered for ", - ftype, " - ", mtype, " - ", group_label)) + log("info", paste0( + "All features filtered for ", + ftype, " - ", mtype, " - ", group_label + )) next } - DBI::dbExecute(con, - "CREATE OR REPLACE TEMP TABLE keep_features (feature_id VARCHAR)") + DBI::dbExecute( + con, + "CREATE OR REPLACE TEMP TABLE keep_features (feature_id VARCHAR)" + ) DBI::dbWriteTable(con, - "keep_features", - data.frame(feature_id = keep_features), - append = TRUE) + "keep_features", + data.frame(feature_id = keep_features), + append = TRUE + ) mtype_label <- matrix_types[[mtype]]$label - long_out_path <- file.path(matrix_path, - sprintf("%s_%s_%s_%s_%s_sparse.parquet", - bug, group_type, group_label, ftype, mtype_label)) + long_out_path <- file.path( + matrix_path, + sprintf( + "%s_%s_%s_%s_%s_sparse.parquet", + bug, group_type, group_label, ftype, mtype_label + ) + ) long_out_path_sql <- gsub("\\\\", "/", long_out_path) # phenotype case phenotype_case <- if (group_type %in% - c("drug_class", "drug_class_year", "drug_class_country")) { + c("drug_class", "drug_class_year", "drug_class_country")) { " CASE WHEN MAX(CASE WHEN f.\"genome_drug.resistant_phenotype\"='Resistant' @@ -451,13 +475,20 @@ skipImbalancedMatrix <- function(genome_ids, " } - strat_col_select <- if (!is.null(stratify_by)) - sprintf(", f.\"%s\"", stratify_column) else "" + strat_col_select <- if (!is.null(stratify_by)) { + sprintf(", f.\"%s\"", stratify_column) + } else { + "" + } - strat_col_group <- if (!is.null(stratify_by)) - sprintf(", f.\"%s\"", stratify_column) else "" + strat_col_group <- if (!is.null(stratify_by)) { + sprintf(", f.\"%s\"", stratify_column) + } else { + "" + } - copy_sql <- sprintf(" + copy_sql <- sprintf( + " COPY ( SELECT f.\"genome_drug.genome_id\" AS genome_id, @@ -478,18 +509,21 @@ skipImbalancedMatrix <- function(genome_ids, TO '%s' (FORMAT 'parquet', COMPRESSION 'zstd') ", - fid, value_col, phenotype_case, strat_col_select, - mview, fid, condition_string, - strat_filter, fid, strat_col_group, fid, - long_out_path_sql) + fid, value_col, phenotype_case, strat_col_select, + mview, fid, condition_string, + strat_filter, fid, strat_col_group, fid, + long_out_path_sql + ) ok <- try(DBI::dbExecute(con, copy_sql), silent = TRUE) # On copy failure, log + continue without stopping entire pipeline if (inherits(ok, "try-error")) { readr::write_lines( - sprintf("%s\tCOPY_failed\t%d\t%s", - group_label, length(genome_ids), phenotype_summary), + sprintf( + "%s\tCOPY_failed\t%d\t%s", + group_label, length(genome_ids), phenotype_summary + ), log_path, append = TRUE ) @@ -530,7 +564,7 @@ skipImbalancedMatrix <- function(genome_ids, # Normalize paths to forward slashes for consistency matrix_path <- gsub("\\\\", "/", file.path(path, paste0("matrix_", stratify_by))) - LOO_path <- gsub("\\\\", "/", file.path(path, paste0("LOO_matrix_", stratify_by))) + LOO_path <- gsub("\\\\", "/", file.path(path, paste0("LOO_matrix_", stratify_by))) if (!dir.exists(matrix_path)) { log("info", paste0("The matrix directory ", matrix_path, " does not exist.")) @@ -626,9 +660,11 @@ skipImbalancedMatrix <- function(genome_ids, out_file <- gsub("\\\\", "/", file.path( LOO_path, - paste0(sub_prefix, "_", stratify_by, "_", - drug_class, "_leaveout_", leave_one_out, "_", - sub_feature, "_sparse.parquet") + paste0( + sub_prefix, "_", stratify_by, "_", + drug_class, "_leaveout_", leave_one_out, "_", + sub_feature, "_sparse.parquet" + ) )) arrow::write_parquet(combined, out_file) created <<- c(created, out_file) @@ -702,7 +738,7 @@ skipImbalancedMatrix <- function(genome_ids, # Build one matrix per feature type and matrix type for (ftype in names(feature_types)) { fview <- feature_types[[ftype]]$view - fid <- feature_types[[ftype]]$id_col + fid <- feature_types[[ftype]]$id_col for (mtype in names(matrix_types)) { binary_only <- matrix_types[[mtype]]$binary_only @@ -722,8 +758,9 @@ skipImbalancedMatrix <- function(genome_ids, # Selected genomes DBI::dbExecute(con, "CREATE OR REPLACE TEMP TABLE selected_genomes (genome_id VARCHAR)") DBI::dbWriteTable(con, "selected_genomes", - data.frame(genome_id = genomes_to_keep), - append = TRUE) + data.frame(genome_id = genomes_to_keep), + append = TRUE + ) # Binary view DBI::dbExecute(con, sprintf(" @@ -763,13 +800,15 @@ skipImbalancedMatrix <- function(genome_ids, DBI::dbExecute(con, "CREATE OR REPLACE TEMP TABLE keep_features (feature_id VARCHAR)") DBI::dbWriteTable(con, "keep_features", - data.frame(feature_id = keep_features), - append = TRUE) + data.frame(feature_id = keep_features), + append = TRUE + ) + - - copy_sql <- sprintf(" + copy_sql <- sprintf( + " COPY ( - SELECT + SELECT f.\"genome_drug.genome_id\" AS genome_id, %s AS feature_id, MAX(CAST(%s AS DOUBLE)) AS value, @@ -779,26 +818,26 @@ skipImbalancedMatrix <- function(genome_ids, JOIN keep_features kf ON %s = kf.feature_id JOIN metadata f ON genome_id = f.\"genome_drug.genome_id\" WHERE resistant_classes <> 'Intermediate' - GROUP BY - f.\"genome_drug.genome_id\", - %s, + GROUP BY + f.\"genome_drug.genome_id\", + %s, resistant_classes - ORDER BY - f.\"genome_drug.genome_id\", + ORDER BY + f.\"genome_drug.genome_id\", %s ) TO '%s' (FORMAT 'parquet', COMPRESSION 'zstd') - ", - fid, # %s -> feature_id expression column name - value_col, # %s -> value column to CAST - mview, # %s -> source view (binary or counts) - fid, # %s -> join to keep_features - fid, # %s -> group by feature id - fid, # %s -> order by feature id - out_file_sql # %s -> destination parquet file + ", + fid, # %s -> feature_id expression column name + value_col, # %s -> value column to CAST + mview, # %s -> source view (binary or counts) + fid, # %s -> join to keep_features + fid, # %s -> group by feature id + fid, # %s -> order by feature id + out_file_sql # %s -> destination parquet file ) - + ok <- try(DBI::dbExecute(con, copy_sql), silent = TRUE) if (inherits(ok, "try-error")) { log("info", paste0("COPY failed for MDR matrix: ", out_file)) diff --git a/R/globals.R b/R/globals.R index a6595d2..131d016 100644 --- a/R/globals.R +++ b/R/globals.R @@ -8,7 +8,6 @@ "_PACKAGE" utils::globalVariables(c( - # Prediction columns from tidymodels ".estimate", ".pred_Resistant", @@ -52,7 +51,6 @@ utils::globalVariables(c( "pair_id", "parts", "phenotype", - "precision", "prefix", "prefix_key", diff --git a/R/plot_ml.R b/R/plot_ml.R index 071e1a1..90121f3 100644 --- a/R/plot_ml.R +++ b/R/plot_ml.R @@ -214,7 +214,6 @@ plotFishers <- function( alpha = 0.05, label_top_n = 5 ) { - required_cols <- c("gene", "adj_p_value", "sig_after_bh") missing_cols <- setdiff(required_cols, colnames(fisher_df)) diff --git a/R/prep_ml.R b/R/prep_ml.R index d47c160..4a5954e 100644 --- a/R/prep_ml.R +++ b/R/prep_ml.R @@ -111,8 +111,10 @@ loadMLInputTibble <- function(parquet_path) { if (exists(".ml_logger")) { log <- .ml_logger("minimal") - log("debug", paste0("ML tibble constructed: ", nrow(ml_input_tibble), - " genomes x ", getNumFeat(ml_input_tibble), " features")) + log("debug", paste0( + "ML tibble constructed: ", nrow(ml_input_tibble), + " genomes x ", getNumFeat(ml_input_tibble), " features" + )) } if (anyDuplicated(dplyr::pull(ml_input_tibble, genome_id)) != 0) { diff --git a/R/run_ML.R b/R/run_ML.R index eba37f8..2ed07e7 100644 --- a/R/run_ML.R +++ b/R/run_ML.R @@ -4,9 +4,11 @@ #' the ML matrices with these new split/CV values instead. #' @noRd .resolveSplitParams <- function(parquet_path, - defaults = list(split = c(0.8, 0), - seed = 5280, - n_fold = 5)) { + defaults = list( + split = c(0.8, 0), + seed = 5280, + n_fold = 5 + )) { # matrix_dir is the directory that contains the parquet files matrix_dir <- normalizePath(dirname(parquet_path)) params_json <- .readMLParameters(matrix_dir) @@ -16,8 +18,8 @@ } list( - split = if (!is.null(params_json$split)) params_json$split else defaults$split, - seed = if (!is.null(params_json$seed)) params_json$seed else defaults$seed, + split = if (!is.null(params_json$split)) params_json$split else defaults$split, + seed = if (!is.null(params_json$seed)) params_json$seed else defaults$seed, n_fold = if (!is.null(params_json$n_fold)) params_json$n_fold else defaults$n_fold ) } @@ -53,8 +55,9 @@ #' #' # LOO analysis stratified by year #' paths_loo <- createMLResultDir("/path/to/results", -#' stratify_by = "year", -#' LOO = TRUE) +#' stratify_by = "year", +#' LOO = TRUE +#' ) #' #' # MDR analysis #' paths_mdr <- createMLResultDir("/path/to/results", MDR = TRUE) @@ -90,16 +93,17 @@ createMLResultDir <- function(path, ) } else { # Determine prefixes (only in non-MDR mode) - full_prefix <- paste0(ifelse(isTRUE(LOO), "LOO_", ""), - ifelse(isTRUE(cross_test), "cross_test_", "")) + full_prefix <- paste0( + ifelse(isTRUE(LOO), "LOO_", ""), + ifelse(isTRUE(cross_test), "cross_test_", "") + ) half_prefix <- ifelse(isTRUE(LOO), "LOO_", "") # Determine suffix suffix <- if (is.null(stratify_by) || identical(stratify_by, "")) { "" } else { - switch( - stratify_by, + switch(stratify_by, "country" = "_country", "year" = "_year", stop("`stratify_by` must be NULL, 'country', or 'year'.") @@ -127,20 +131,20 @@ createMLResultDir <- function(path, return(paths) } - # createAllMLResultDir <- function(path) { - # createMLResultDir(path, stratify_by = NULL, LOO = FALSE, cross_test = FALSE, MDR = FALSE) - # createMLResultDir(path, stratify_by = NULL, LOO = FALSE, cross_test = TRUE, MDR = FALSE) - # createMLResultDir(path, stratify_by = NULL, LOO = FALSE, cross_test = FALSE, MDR = TRUE) - # createMLResultDir(path, stratify_by = "year", LOO = FALSE, cross_test = FALSE, MDR = FALSE) - # createMLResultDir(path, stratify_by = "year", LOO = FALSE, cross_test = TRUE, MDR = FALSE) - # createMLResultDir(path, stratify_by = "year", LOO = TRUE, cross_test = FALSE, MDR = FALSE) - # createMLResultDir(path, stratify_by = "year", LOO = TRUE, cross_test = TRUE, MDR = FALSE) - # createMLResultDir(path, stratify_by = "country", LOO = FALSE, cross_test = FALSE, MDR = FALSE) - # createMLResultDir(path, stratify_by = "country", LOO = FALSE, cross_test = TRUE, MDR = FALSE) - # createMLResultDir(path, stratify_by = "country", LOO = TRUE, cross_test = FALSE, MDR = FALSE) - # createMLResultDir(path, stratify_by = "country", LOO = TRUE, cross_test = TRUE, MDR = FALSE) - # } - # +# createAllMLResultDir <- function(path) { +# createMLResultDir(path, stratify_by = NULL, LOO = FALSE, cross_test = FALSE, MDR = FALSE) +# createMLResultDir(path, stratify_by = NULL, LOO = FALSE, cross_test = TRUE, MDR = FALSE) +# createMLResultDir(path, stratify_by = NULL, LOO = FALSE, cross_test = FALSE, MDR = TRUE) +# createMLResultDir(path, stratify_by = "year", LOO = FALSE, cross_test = FALSE, MDR = FALSE) +# createMLResultDir(path, stratify_by = "year", LOO = FALSE, cross_test = TRUE, MDR = FALSE) +# createMLResultDir(path, stratify_by = "year", LOO = TRUE, cross_test = FALSE, MDR = FALSE) +# createMLResultDir(path, stratify_by = "year", LOO = TRUE, cross_test = TRUE, MDR = FALSE) +# createMLResultDir(path, stratify_by = "country", LOO = FALSE, cross_test = FALSE, MDR = FALSE) +# createMLResultDir(path, stratify_by = "country", LOO = FALSE, cross_test = TRUE, MDR = FALSE) +# createMLResultDir(path, stratify_by = "country", LOO = TRUE, cross_test = FALSE, MDR = FALSE) +# createMLResultDir(path, stratify_by = "country", LOO = TRUE, cross_test = TRUE, MDR = FALSE) +# } +# #' Create machine learning input list #' @@ -174,8 +178,9 @@ createMLResultDir <- function(path, #' #' # Cross-test with year stratification #' inputs_ct <- createMLinputList("/path/to/results", -#' stratify_by = "year", -#' cross_test = TRUE) +#' stratify_by = "year", +#' cross_test = TRUE +#' ) #' #' # MDR analysis #' inputs_mdr <- createMLinputList("/path/to/results", MDR = TRUE) @@ -187,10 +192,10 @@ createMLinputList <- function(path, LOO = FALSE, MDR = FALSE, cross_test = FALSE) { - # Validate inputs - if (!is.character(path) || length(path) != 1 || is.na(path)) + if (!is.character(path) || length(path) != 1 || is.na(path)) { stop("`path` must be a valid file path string.") + } path <- normalizePath(path) @@ -225,21 +230,17 @@ createMLinputList <- function(path, # Multi-drug resistance models # ============================ if (MDR) { - parsed <- tibble::tibble(ref_file = files_vec) |> dplyr::mutate( parts = stringr::str_split(basename(ref_file), "_"), - species = purrr::map_chr(parts, ~ .x[1]), - mdr_tag = purrr::map_chr(parts, ~ .x[2]), # always "MDR" + mdr_tag = purrr::map_chr(parts, ~ .x[2]), # always "MDR" phenotype = purrr::map_chr(parts, ~ paste(.x[3:4], collapse = "_")), # Feature is 5th + 6th tokens feature_type = purrr::map_chr(parts, ~ .x[5]), feature_subtype = purrr::map_chr(parts, ~ stringr::str_remove(.x[6], "_sparse.parquet")), - feature = purrr::map2_chr(feature_type, feature_subtype, paste, sep = "_"), - output_prefix = paste0("MDR_", phenotype, "_", feature) ) @@ -247,38 +248,43 @@ createMLinputList <- function(path, dplyr::mutate( test_file = NA_character_, matrix_path = paths$matrix_path, - out_perf = paths$ML_performance, - out_top = paths$ML_top_features, - out_models= paths$ML_models, - out_pred = paths$ML_prediction + out_perf = paths$ML_performance, + out_top = paths$ML_top_features, + out_models = paths$ML_models, + out_pred = paths$ML_prediction ) return(out) - # ============================ - # For all other modeling types - # ============================ + # ============================ + # For all other modeling types + # ============================ } else { - parsed <- tibble::tibble(ref_file = files_vec) |> dplyr::mutate( - parts = stringr::str_split(basename(ref_file), "_"), + parts = stringr::str_split(basename(ref_file), "_"), i_sparse = purrr::map_int(parts, ~ .get_idx(.x, "sparse.parquet")), - i_strat = purrr::map_int(parts, ~ { - if (is.null(stratify_by)) return(NA_integer_) + i_strat = purrr::map_int(parts, ~ { + if (is.null(stratify_by)) { + return(NA_integer_) + } .get_idx(.x, stratify_by) }), # Feature = last two tokens before sparse.parquet feature = purrr::map2_chr(parts, i_sparse, ~ { - i <- .y; x <- .x - if (is.na(i) || i < 3) return(NA_character_) + i <- .y + x <- .x + if (is.na(i) || i < 3) { + return(NA_character_) + } paste(x[(i - 2):(i - 1)], collapse = "_") }), # Drug or drug class extraction drug_or_class = purrr::map2_chr(parts, i_strat, ~ { - i <- .y; x <- .x + i <- .y + x <- .x # Stratified models if (!is.na(i)) { @@ -304,32 +310,40 @@ createMLinputList <- function(path, # Stratification value (if present) strat_value = purrr::map2_chr(parts, i_strat, ~ { - i <- .y; x <- .x - if (is.na(i)) return("") + i <- .y + x <- .x + if (is.na(i)) { + return("") + } # default position is two tokens after the strat label j <- i + 2 # if there's an intervening 'leaveout', skip over it if (j <= length(x) && identical(x[j], "leaveout")) j <- j + 1 - if (j <= length(x)) return(x[j]) - "" # no stratification + if (j <= length(x)) { + return(x[j]) + } + "" # no stratification }), # Prefix key for grouping prefix_key = purrr::map2_chr(parts, i_strat, ~ { - i <- .y; x <- .x + i <- .y + x <- .x # Case A: stratified -> prefix before the stratify label if (!is.na(i)) { - if (i - 1 >= 1) return(paste(x[1:(i - 1)], collapse = "_")) + if (i - 1 >= 1) { + return(paste(x[1:(i - 1)], collapse = "_")) + } return("") } # Case B: unstratified -> prefix is first two tokens - if (x[2] == "drug" && x[3] != "class"){ + if (x[2] == "drug" && x[3] != "class") { # Case A: Cje_drug_X return(paste(x[1:2], collapse = "_")) } - if (x[2] == "drug" && x[3] == "class"){ + if (x[2] == "drug" && x[3] == "class") { # Case A: Cje_drug_X return(paste(x[1:3], collapse = "_")) } @@ -345,18 +359,17 @@ createMLinputList <- function(path, test_file = NA_character_, output_prefix = gsub("_sparse\\.parquet$", "", basename(ref_file)), matrix_path = paths$matrix_path, - out_perf = paths$ML_performance, - out_top = paths$ML_top_features, - out_models= paths$ML_models, - out_pred = paths$ML_prediction + out_perf = paths$ML_performance, + out_top = paths$ML_top_features, + out_models = paths$ML_models, + out_pred = paths$ML_prediction ) return(out) - # ============================ - # Cross-test modeling, no LOO - # ============================ + # ============================ + # Cross-test modeling, no LOO + # ============================ } else if (cross_test && !LOO) { - if (is.null(stratify_by)) { # Case A: stratify_by = NULL, pair across abx within same feature + prefix pairs <- parsed |> @@ -366,8 +379,10 @@ createMLinputList <- function(path, dplyr::select(test_file = ref_file, feature, prefix_key, strat_value, test_drug = drug_or_class), by = c("feature", "prefix_key", "strat_value") ) |> - dplyr::filter(ref_file != test_file, - ref_drug != test_drug) |> + dplyr::filter( + ref_file != test_file, + ref_drug != test_drug + ) |> dplyr::distinct() |> dplyr::mutate( output_prefix = paste0( @@ -380,10 +395,10 @@ createMLinputList <- function(path, out <- pairs |> dplyr::mutate( matrix_path = paths$matrix_path, - out_perf = paths$ML_performance, - out_top = paths$ML_top_features, - out_models= paths$ML_models, - out_pred = paths$ML_prediction + out_perf = paths$ML_performance, + out_top = paths$ML_top_features, + out_models = paths$ML_models, + out_pred = paths$ML_prediction ) return(out) @@ -392,30 +407,29 @@ createMLinputList <- function(path, # Case B: stratify_by != NULL, pair same drug/class, prefix, feature, # but across different stratification groups pairs <- parsed |> - dplyr::select(ref_file, feature, prefix_key, strat_value, - drug_or_class) |> - + dplyr::select( + ref_file, feature, prefix_key, strat_value, + drug_or_class + ) |> # self-join ONLY on prefix_key, drug/class, feature dplyr::inner_join( parsed |> - dplyr::select(test_file = ref_file, - feature, prefix_key, strat_value_test = strat_value, - drug_or_class), + dplyr::select( + test_file = ref_file, + feature, prefix_key, strat_value_test = strat_value, + drug_or_class + ), by = c("prefix_key", "feature", "drug_or_class") ) |> - # do NOT test file against itself dplyr::filter(ref_file != test_file) |> - # enforce different stratification group dplyr::filter(strat_value != strat_value_test) |> - # remove symmetric duplicates (A,B == B,A) dplyr::rowwise() |> dplyr::mutate(pair_id = paste(sort(c(ref_file, test_file)), collapse = "||")) |> dplyr::ungroup() |> dplyr::distinct(pair_id, .keep_all = TRUE) |> - dplyr::mutate( output_prefix = paste0( prefix_key, "_", @@ -429,19 +443,18 @@ createMLinputList <- function(path, out <- pairs |> dplyr::mutate( matrix_path = paths$matrix_path, - out_perf = paths$ML_performance, - out_top = paths$ML_top_features, - out_models= paths$ML_models, - out_pred = paths$ML_prediction + out_perf = paths$ML_performance, + out_top = paths$ML_top_features, + out_models = paths$ML_models, + out_pred = paths$ML_prediction ) return(out) - # ============================ - # Cross-test + LOO modeling - # ============================ + # ============================ + # Cross-test + LOO modeling + # ============================ } else if (cross_test && LOO) { - # LOO requires special directory structure resolution test_path <- file.path(path, stringr::str_remove(basename(paths$matrix_path), "^LOO_")) test_path <- normalizePath(test_path) @@ -461,10 +474,10 @@ createMLinputList <- function(path, out <- loo_pairs |> dplyr::mutate( matrix_path = paths$matrix_path, - out_perf = paths$ML_performance, - out_top = paths$ML_top_features, - out_models= paths$ML_models, - out_pred = paths$ML_prediction + out_perf = paths$ML_performance, + out_top = paths$ML_top_features, + out_models = paths$ML_models, + out_pred = paths$ML_prediction ) return(out) @@ -472,9 +485,11 @@ createMLinputList <- function(path, } # If we ever get here, something wasn't covered - stop("Unhandled combination of arguments: ", - "MDR=", MDR, ", cross_test=", cross_test, ", LOO=", LOO, - ", stratify_by=", if (is.null(stratify_by)) "NULL" else stratify_by) + stop( + "Unhandled combination of arguments: ", + "MDR=", MDR, ", cross_test=", cross_test, ", LOO=", LOO, + ", stratify_by=", if (is.null(stratify_by)) "NULL" else stratify_by + ) } @@ -544,13 +559,15 @@ createMLinputList <- function(path, #' #' # Run with more threads and minimal output #' runMDRmodels("/path/to/results", -#' threads = 32, -#' verbose = FALSE) +#' threads = 32, +#' verbose = FALSE +#' ) #' #' # Run without saving model fits (save disk space) #' runMDRmodels("/path/to/results", -#' threads = 16, -#' return_fit = FALSE) +#' threads = 16, +#' return_fit = FALSE +#' ) #' } #' #' @seealso @@ -571,12 +588,12 @@ runMDRmodels <- function(path, use_saved_split = TRUE, shuffle_labels = FALSE, use_pca = FALSE) { - files <- createMLinputList(path, - stratify_by = NULL, - LOO = FALSE, - cross_test = FALSE, - MDR = TRUE) + stratify_by = NULL, + LOO = FALSE, + cross_test = FALSE, + MDR = TRUE + ) if (nrow(files) == 0) { message("No MDR files found to process. Exiting.") @@ -594,18 +611,19 @@ runMDRmodels <- function(path, # Auto tags for shuffled and PCA shuffle_tag <- if (isTRUE(shuffle_labels)) "shuffled_" else "" - pca_tag <- if (isTRUE(use_pca)) paste0("_pca", as.character(pca_threshold)) else "" + pca_tag <- if (isTRUE(use_pca)) paste0("_pca", as.character(pca_threshold)) else "" results_list <- future.apply::future_lapply( seq_len(nrow(files)), FUN = function(i) { - - ref_parquet <- files$ref_file[i] + ref_parquet <- files$ref_file[i] output_prefix <- files$output_prefix[i] if (interactive()) { - message(sprintf("[runMDRmodels] %d/%d: %s", - i, nrow(files), basename(ref_parquet))) + message(sprintf( + "[runMDRmodels] %d/%d: %s", + i, nrow(files), basename(ref_parquet) + )) } ml_input <- loadMLInputTibble(ref_parquet) @@ -619,32 +637,37 @@ runMDRmodels <- function(path, list(split = split, seed = 5280, n_fold = n_fold) } - res <- try({ - runMLPipeline( - ml_input_tibble = ml_input, - test_data = NA, - model = "LR", - split = sp$split, - n_fold = sp$n_fold, - prop_vi_top_feats = prop_vi_top_feats, - n_top_feats = NA, - use_pca = use_pca, - pca_threshold = pca_threshold, - shuffle_labels = shuffle_labels, - penalty_vec = 10^seq(-4, -1, length.out = 10), - mix_vec = 0:5 / 5, - select_best_metric = "mcc", - seed = sp$seed, - verbose = verbose, - return_tune_res = return_tune_res, - return_fit = return_fit, - return_pred = return_pred - ) - }, silent = TRUE) + res <- try( + { + runMLPipeline( + ml_input_tibble = ml_input, + test_data = NA, + model = "LR", + split = sp$split, + n_fold = sp$n_fold, + prop_vi_top_feats = prop_vi_top_feats, + n_top_feats = NA, + use_pca = use_pca, + pca_threshold = pca_threshold, + shuffle_labels = shuffle_labels, + penalty_vec = 10^seq(-4, -1, length.out = 10), + mix_vec = 0:5 / 5, + select_best_metric = "mcc", + seed = sp$seed, + verbose = verbose, + return_tune_res = return_tune_res, + return_fit = return_fit, + return_pred = return_pred + ) + }, + silent = TRUE + ) if (inherits(res, "try-error")) { - warning("Model failed for: ", output_prefix, - "\n Error: ", attr(res, "condition")$message) + warning( + "Model failed for: ", output_prefix, + "\n Error: ", attr(res, "condition")$message + ) return(NULL) } @@ -652,19 +675,25 @@ runMDRmodels <- function(path, base <- paste0(shuffle_tag, output_prefix, pca_tag) if (!is.null(res$performance_tibble)) { - readr::write_tsv(res$performance_tibble, - file.path(files$out_perf[i], paste0(base, "_performance.tsv"))) + readr::write_tsv( + res$performance_tibble, + file.path(files$out_perf[i], paste0(base, "_performance.tsv")) + ) } if (!is.null(res$top_feat_tibble)) { - readr::write_tsv(res$top_feat_tibble, - file.path(files$out_top[i], paste0(base, "_top_features.tsv"))) + readr::write_tsv( + res$top_feat_tibble, + file.path(files$out_top[i], paste0(base, "_top_features.tsv")) + ) } if (!is.null(res$fit)) { saveRDS(res$fit, file.path(files$out_models[i], paste0(base, "_model_fit.rds"))) } if (!is.null(res$pred)) { - readr::write_tsv(res$pred, - file.path(files$out_pred[i], paste0(base, "_prediction.tsv"))) + readr::write_tsv( + res$pred, + file.path(files$out_pred[i], paste0(base, "_prediction.tsv")) + ) } NULL @@ -783,21 +812,24 @@ runMDRmodels <- function(path, #' #' # Cross-test with year stratification #' runMLmodels("/path/to/results", -#' stratify_by = "year", -#' cross_test = TRUE, -#' threads = 32) +#' stratify_by = "year", +#' cross_test = TRUE, +#' threads = 32 +#' ) #' #' # LOO analysis stratified by country with cross-testing #' runMLmodels("/path/to/results", -#' stratify_by = "country", -#' LOO = TRUE, -#' cross_test = TRUE, -#' verbose = TRUE) +#' stratify_by = "country", +#' LOO = TRUE, +#' cross_test = TRUE, +#' verbose = TRUE +#' ) #' #' # Run without saving model fits (save disk space) #' runMLmodels("/path/to/results", -#' stratify_by = "year", -#' return_fit = FALSE) +#' stratify_by = "year", +#' return_fit = FALSE +#' ) #' } #' #' @seealso @@ -823,19 +855,21 @@ runMLmodels <- function(path, use_saved_split = TRUE, shuffle_labels = FALSE, use_pca = FALSE) { - if (!is.null(stratify_by)) { - if (!is.character(stratify_by) || length(stratify_by) != 1L) + if (!is.character(stratify_by) || length(stratify_by) != 1L) { stop("`stratify_by` must be NULL or a single string: 'year' or 'country'.") - if (!stratify_by %in% c("year", "country")) + } + if (!stratify_by %in% c("year", "country")) { stop("`stratify_by` must be NULL, 'year', or 'country'.") + } } files <- createMLinputList(path, - stratify_by = stratify_by, - LOO = LOO, - MDR = FALSE, - cross_test = cross_test) + stratify_by = stratify_by, + LOO = LOO, + MDR = FALSE, + cross_test = cross_test + ) if (nrow(files) == 0) { message("No files found to process. Exiting.") @@ -864,8 +898,7 @@ runMLmodels <- function(path, strat_suffix <- if (is.null(stratify_by) || identical(stratify_by, "")) { "" } else { - switch( - stratify_by, + switch(stratify_by, "country" = "_country", "year" = "_year", stop("`stratify_by` must be NULL, 'year', or 'country'.") @@ -874,18 +907,19 @@ runMLmodels <- function(path, # Auto naming for shuffled and PCA shuffle_tag <- if (isTRUE(shuffle_labels)) "shuffled_" else "" - pca_tag <- if (isTRUE(use_pca)) paste0("_pca", as.character(pca_threshold)) else "" + pca_tag <- if (isTRUE(use_pca)) paste0("_pca", as.character(pca_threshold)) else "" results_list <- future.apply::future_lapply( seq_len(nrow(files)), FUN = function(i) { - - ref_parquet <- files$ref_file[i] + ref_parquet <- files$ref_file[i] output_prefix <- files$output_prefix[i] if (interactive()) { - message(sprintf("[runMLmodels] %d/%d: %s", - i, nrow(files), basename(ref_parquet))) + message(sprintf( + "[runMLmodels] %d/%d: %s", + i, nrow(files), basename(ref_parquet) + )) } ml_input <- loadMLInputTibble(ref_parquet) @@ -910,32 +944,37 @@ runMLmodels <- function(path, list(split = split, seed = 5280, n_fold = n_fold) } - res <- try({ - runMLPipeline( - ml_input_tibble = ml_input, - test_data = test_data, - model = "LR", - split = sp$split, - n_fold = sp$n_fold, - prop_vi_top_feats = prop_vi_top_feats, - n_top_feats = NA, - use_pca = use_pca, - pca_threshold = pca_threshold, - shuffle_labels = shuffle_labels, - penalty_vec = 10^seq(-4, -1, length.out = 10), - mix_vec = 0:5 / 5, - select_best_metric = "mcc", - seed = sp$seed, - verbose = verbose, - return_tune_res = return_tune_res, - return_fit = return_fit, - return_pred = return_pred - ) - }, silent = TRUE) + res <- try( + { + runMLPipeline( + ml_input_tibble = ml_input, + test_data = test_data, + model = "LR", + split = sp$split, + n_fold = sp$n_fold, + prop_vi_top_feats = prop_vi_top_feats, + n_top_feats = NA, + use_pca = use_pca, + pca_threshold = pca_threshold, + shuffle_labels = shuffle_labels, + penalty_vec = 10^seq(-4, -1, length.out = 10), + mix_vec = 0:5 / 5, + select_best_metric = "mcc", + seed = sp$seed, + verbose = verbose, + return_tune_res = return_tune_res, + return_fit = return_fit, + return_pred = return_pred + ) + }, + silent = TRUE + ) if (inherits(res, "try-error")) { - warning("Model failed for: ", output_prefix, - "\n Error: ", attr(res, "condition")$message) + warning( + "Model failed for: ", output_prefix, + "\n Error: ", attr(res, "condition")$message + ) return(NULL) } @@ -943,19 +982,25 @@ runMLmodels <- function(path, base <- paste0(shuffle_tag, config_prefix, output_prefix, pca_tag, strat_suffix) if (!is.null(res$performance_tibble)) { - readr::write_tsv(res$performance_tibble, - file.path(files$out_perf[i], paste0(base, "_performance.tsv"))) + readr::write_tsv( + res$performance_tibble, + file.path(files$out_perf[i], paste0(base, "_performance.tsv")) + ) } if (!is.null(res$top_feat_tibble)) { - readr::write_tsv(res$top_feat_tibble, - file.path(files$out_top[i], paste0(base, "_top_features.tsv"))) + readr::write_tsv( + res$top_feat_tibble, + file.path(files$out_top[i], paste0(base, "_top_features.tsv")) + ) } if (!is.null(res$fit)) { saveRDS(res$fit, file.path(files$out_models[i], paste0(base, "_model_fit.rds"))) } if (!is.null(res$pred)) { - readr::write_tsv(res$pred, - file.path(files$out_pred[i], paste0(base, "_prediction.tsv"))) + readr::write_tsv( + res$pred, + file.path(files$out_pred[i], paste0(base, "_prediction.tsv")) + ) } NULL @@ -973,7 +1018,6 @@ runMLmodels <- function(path, } - #' Run the entire AMR ML pipeline from a parquet-backed DuckDB #' #' This function provides a complete end-to-end AMR machine learning workflow. @@ -1006,11 +1050,12 @@ runModelingPipeline <- function(parquet_duckdb_path, pca_threshold = 0.99, verbose = TRUE, use_saved_split = TRUE) { - parquet_duckdb_path <- normalizePath(parquet_duckdb_path) if (!file.exists(parquet_duckdb_path)) { - stop("Parquet-backed DuckDB at ", parquet_duckdb_path, " not found.\n", - "Are you using `{Bug}.duckdb` instead of `{Bug}_parquet.duckdb?`") + stop( + "Parquet-backed DuckDB at ", parquet_duckdb_path, " not found.\n", + "Are you using `{Bug}.duckdb` instead of `{Bug}_parquet.duckdb?`" + ) } out_root <- dirname(parquet_duckdb_path) @@ -1024,9 +1069,9 @@ runModelingPipeline <- function(parquet_duckdb_path, generateMLInputs( parquet_duckdb_path = parquet_duckdb_path, out_path = out_root, - n_fold = n_fold, - split = split, - min_n = min_n, + n_fold = n_fold, + split = split, + min_n = min_n, verbosity = if (verbose) "minimal" else "debug" ) @@ -1089,12 +1134,13 @@ runModelingPipeline <- function(parquet_duckdb_path, # All done! if (verbose) { message("\n=== AMR-ML Pipeline Complete ===") - message("All matrices, models, top feature lists, and performance metrics saved under:\n ", - out_root) + message( + "All matrices, models, top feature lists, and performance metrics saved under:\n ", + out_root + ) message("\nTo inspect model outputs, see directories such as:") message(" ML_performance/, ML_models/, ML_prediction/, ML_top_features/") } invisible(out_root) } - diff --git a/R/run_ml_pipeline.R b/R/run_ml_pipeline.R index 2a97c00..ad9f691 100644 --- a/R/run_ml_pipeline.R +++ b/R/run_ml_pipeline.R @@ -93,20 +93,21 @@ runMLPipeline <- function( .checkArgReturnPred(return_pred) - # Set `n_fold` to `NA` if not using cross-validation. if (split[2] != 0) { n_fold <- NA } # Confirm resolved split params - if (verbose) { - mode <- if (split[2] == 0) "cv" else "splits" - message(sprintf("ML split mode: %s | split = c(%.2f, %.2f) | n_fold = %s | seed = %s", - mode, split[1], split[2], - ifelse(is.na(n_fold), "NA", as.character(n_fold)), - as.character(seed))) - } + if (verbose) { + mode <- if (split[2] == 0) "cv" else "splits" + message(sprintf( + "ML split mode: %s | split = c(%.2f, %.2f) | n_fold = %s | seed = %s", + mode, split[1], split[2], + ifelse(is.na(n_fold), "NA", as.character(n_fold)), + as.character(seed) + )) + } # Create a variable indicating whether external `test_data` was provided. This # will be set to `TRUE` later if the `test_data` argument is not `NA`. @@ -116,10 +117,10 @@ runMLPipeline <- function( # Determine whether multi-class classification is to be performed. if (as.character(.getTargetVarName(ml_input_tibble)) == "resistant_classes") { - multi_class <- TRUE - } else { - multi_class <- FALSE - } + multi_class <- TRUE + } else { + multi_class <- FALSE + } if (model != "LR" & multi_class) { stop(paste( @@ -262,7 +263,7 @@ runMLPipeline <- function( mix_vec = mix_vec ) } - + recipe <- buildRecipe(train_data, use_pca = use_pca, pca_threshold = pca_threshold @@ -421,14 +422,16 @@ runMLPipeline <- function( all_results[["fit"]] <- fit } - if(return_pred) { - if(!multi_class){ + if (return_pred) { + if (!multi_class) { all_results[["pred"]] <- test_data_plus_predictions |> - dplyr::select(c(genome_id, .pred_class, .pred_Resistant, - .pred_Susceptible, genome_drug.resistant_phenotype)) - } - all_results[["pred"]] <- test_data_plus_predictions + dplyr::select(c( + genome_id, .pred_class, .pred_Resistant, + .pred_Susceptible, genome_drug.resistant_phenotype + )) } + all_results[["pred"]] <- test_data_plus_predictions + } return(all_results) } diff --git a/vignettes/intro.Rmd b/vignettes/intro.Rmd index 996eb6b..af5bc8e 100644 --- a/vignettes/intro.Rmd +++ b/vignettes/intro.Rmd @@ -264,19 +264,19 @@ ml_tibble_reduced <- removeTopFeats(ml_tibble, top_features) ### Precision-recall curve ```{r plot-prc} -test_data_plus_predictions <- readr::read_tsv(results/ML_pred/Sfl_drug_AMP_domains_binary_prediction.tsv) +test_data_plus_predictions <- readr::read_tsv(results / ML_pred / Sfl_drug_AMP_domains_binary_prediction.tsv) plotPRC(test_data_plus_predictions) ``` ### ROC curve ```{r plot-roc} -test_data_plus_predictions <- readr::read_tsv(results/ML_pred/Sfl_drug_AMP_domains_binary_prediction.tsv) +test_data_plus_predictions <- readr::read_tsv(results / ML_pred / Sfl_drug_AMP_domains_binary_prediction.tsv) plotROC(test_data_plus_predictions) ``` ### Variable importance plot ```{r plot-vi} -topfeat <- readr::read_tsv(results/ML_top_features/Sfl_drug_AMP_domains_binary_top_features.tsv) +topfeat <- readr::read_tsv(results / ML_top_features / Sfl_drug_AMP_domains_binary_top_features.tsv) plotTopFeatsVI(topfeat) ``` ### Baseline comparison barplot @@ -326,7 +326,6 @@ You can label the top N features to highlight the strongest hits (default is 5) ```{r} plotFishers(fisher_results) plotFishers(fisher_results, alpha = 0.01, label_top_n = 5) - ``` ## Wrapper to run all models @@ -338,14 +337,15 @@ Given a DuckDB file produced by `runDataProcessing()`, it: 5. saves performance metrics, fitted models, predictions, and top feature rankings ``` {r} runModelingPipeline(parquet_duckdb_path, - threads = 16, - n_fold = 5, - split = c(1, 0), - min_n = 25, - prop_vi_top_feats = c(0, 1), - pca_threshold = 0.99, - verbose = TRUE, - use_saved_split = TRUE) + threads = 16, + n_fold = 5, + split = c(1, 0), + min_n = 25, + prop_vi_top_feats = c(0, 1), + pca_threshold = 0.99, + verbose = TRUE, + use_saved_split = TRUE +) ``` Merge the performance and top features of each kind of models into a parquet that will serve as starting data for `amRshiny` package @@ -357,7 +357,7 @@ buildPerformancePq( LOO = FALSE, MDR = FALSE, cross_test = FALSE, - out_parquet = NULL, + out_parquet = NULL, compression = "zstd", verbose = TRUE ) @@ -367,8 +367,8 @@ buildTopFeatsPq( LOO = FALSE, MDR = FALSE, cross_test = FALSE, - out_parquet = NULL, + out_parquet = NULL, compression = "zstd", verbose = TRUE -) +) ``` From 62668b41965ff57b8832fcfa17ebd0df404436c1 Mon Sep 17 00:00:00 2001 From: Abhirupa Ghosh <100681585+AbhirupaGhosh@users.noreply.github.com> Date: Mon, 23 Mar 2026 14:41:19 -0600 Subject: [PATCH 3/4] Implement computeFeatureScore function for feature scoring Added computeFeatureScore function to calculate feature scores based on various metrics and cluster information. --- R/feature_rescoring.R | 198 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 198 insertions(+) diff --git a/R/feature_rescoring.R b/R/feature_rescoring.R index 5c4fa0b..b9a66da 100644 --- a/R/feature_rescoring.R +++ b/R/feature_rescoring.R @@ -69,3 +69,201 @@ computeFeatureImprovement <- function( return(features_rescored) } + +computeFeatureScore <- function( + all_feature_parquet, + feature_cluster_parquet +) { + top_features <- arrow::read_parquet(normalizePath(all_feature_parquet)) +top_features <- top_features |> + dplyr::mutate( + shuffled = stringr::str_detect(prefix_key, "^shuffled_"), + species = prefix_key |> + stringr::str_remove("^shuffled_") |> + stringr::str_remove("(_drug_class|_drug).*") + ) + +## -------------------------------- +## 1. Scale capability (data-driven) +## -------------------------------- + +scale_capability <- top_features |> + dplyr::distinct(feature_type, feature_subtype) |> + dplyr::group_by(feature_type) |> + dplyr::summarise( + expected_types = dplyr::n_distinct(feature_subtype), + expected_types_csv = paste(sort(feature_subtype), collapse = ","), + .groups = "drop" + ) + +## -------------------------------- +## 2. Importance → contribution → rank +## -------------------------------- + +ranked_features <- top_features |> + dplyr::group_by( + species, + drug_label, + drug_or_class, + feature_type, + feature_subtype, + shuffled + ) |> + dplyr::mutate( + contribution = Importance / sum(Importance, na.rm = TRUE), + rank = dplyr::dense_rank(dplyr::desc(contribution)), + max_rank = max(rank), + min_rank = min(rank), + mean_rank = mean(rank), + rank_score = 1 - (mean_rank - 1) / max_rank, + ) |> + dplyr::ungroup() |> +dplyr::select( + species, + drug_label, + drug_or_class, + feature_type, + feature_subtype, + shuffled, + Variable, + contribution, + rank, + max_rank, + min_rank, + mean_rank +) + +## -------------------------------- +## 3. Collapse to protein level +## -------------------------------- + +feature_cluster <- arrow::read_parquet(normalizePath(feature_cluster_parquet)) + +ranked_features <- ranked_features |> +dplyr::mutate( + Variable = dplyr::case_when( + feature_type == "domains" ~ sub("_.+$", "", Variable), + feature_type == "proteins" ~ sub("fig.", "fig|", Variable, fixed = TRUE), + TRUE ~ Variable + ) + ) |> +dplyr::left_join(feature_cluster, by = dplyr::join_by(Variable == feature)) + +protein_scale_realization <- ranked_features |> + dplyr::filter(!is.na(cluster), shuffled == FALSE) |> + dplyr::group_by( + species, + drug_label, + drug_or_class, + cluster, + feature_type + ) |> + dplyr::summarise( + observed_types = dplyr::n_distinct(feature_subtype), + observed_types_csv = paste(sort(unique(feature_subtype)), collapse = ","), + .groups = "drop" + ) |> + dplyr::left_join(scale_capability, by = "feature_type") |> + dplyr::mutate( + scale_realization = dplyr::case_when( + observed_types == expected_types ~ "full_realization", + observed_types < expected_types ~ "partial_realization" + ) + ) + +protein_scale_summary <- protein_scale_realization |> + dplyr::group_by(species,drug_label,drug_or_class, cluster) |> + dplyr::summarise( + n_scales = dplyr::n_distinct(feature_type), + fully_realized_scales = + sum(scale_realization == "full_realization"), + partially_realized_scales = + sum(scale_realization == "partial_realization"), + scale_support_csv = + paste( + feature_type, + "(", observed_types_csv, "/", expected_types, ")", + collapse = "; " + ), + .groups = "drop" + ) |> + dplyr::mutate( + realization_score = + (fully_realized_scales + + 0.5 * partially_realized_scales) / + (fully_realized_scales + partially_realized_scales) + )|> + dplyr::mutate( + coverage_boost = + n_scales / max(n_scales) + ) |> + dplyr::mutate( + scale_factor = realization_score * coverage_boost + ) + +## -------------------------------- +## 4. Shuffle vs non-shuffle +## -------------------------------- + +shuffle_delta <- ranked_features |> + dplyr::filter(!is.na(cluster)) |> + dplyr::group_by( + species,drug_label, + drug_or_class, + cluster + ) |> +dplyr::summarise(mean_rank_nonshuffle = mean(mean_rank[shuffled == FALSE], na.rm = TRUE), + mean_rank_shuffle = mean(mean_rank[shuffled == TRUE], na.rm = TRUE), + delta_rank = mean_rank_shuffle - mean_rank_nonshuffle , + # If non-shuffled is missing -> NA (no evidence). If shuffled missing -> +Inf improvement. + improvement = dplyr::case_when( + !is.na(mean_rank_nonshuffle) ~ tidyr::replace_na(mean_rank_shuffle, Inf) - mean_rank_nonshuffle, + TRUE ~ NA_real_ + ), + + # "Good" if non-shuffled exists AND (non-shuffled < shuffled OR shuffled is missing) + good_feature = !is.na(mean_rank_nonshuffle) & improvement > 0, + .groups = "drop" ) + +## -------------------------------- +## 5. Cluster scoring +## -------------------------------- +robustness <- ranked_features |> +dplyr::filter(!is.na(cluster), shuffled == FALSE) |> +dplyr::group_by(species, drug_label,drug_or_class, cluster) |> +dplyr::summarize(median_contribution = median(contribution), + median_rank = median(mean_rank), .groups = "drop") |> +dplyr::left_join(protein_scale_summary |> + dplyr::distinct(species, drug_label,drug_or_class, cluster, n_scales, scale_factor), + by = c("species", "drug_label", "drug_or_class", "cluster")) |> +dplyr::left_join(shuffle_delta |> + dplyr::distinct(species, drug_label,drug_or_class, cluster, delta_rank, good_feature), + by = c("species", "drug_label","drug_or_class", "cluster"))|> + dplyr::group_by(species, drug_label,drug_or_class) |> + dplyr::mutate( + # higher = better + contrib_score = dplyr::percent_rank(median_contribution), + + # lower rank = better → invert + stability_score = 1 - dplyr::percent_rank(median_rank), + + # robustness: handle NaN = strongest case (missing in shuffle) + delta_score = dplyr::case_when( + is.nan(delta_rank) ~ 1, # best possible signal + delta_rank > 0 ~ dplyr::percent_rank(delta_rank), # reward + delta_rank == 0 ~ 0, # neutral + delta_rank < 0 ~ -dplyr::percent_rank(abs(delta_rank)) # penalize + ), + + # already bounded [0,1] + scale_score = scale_factor, + + robustness_score = + contrib_score * + stability_score * + scale_factor * + delta_score + ) + + return(robustness) +} From a3888418fa422c54b7d0b35f71579f74133d59e0 Mon Sep 17 00:00:00 2001 From: AbhirupaGhosh Date: Mon, 23 Mar 2026 20:46:33 +0000 Subject: [PATCH 4/4] Style code (GHA) --- R/feature_rescoring.R | 323 +++++++++++++++++++++--------------------- 1 file changed, 165 insertions(+), 158 deletions(-) diff --git a/R/feature_rescoring.R b/R/feature_rescoring.R index b9a66da..a3e3165 100644 --- a/R/feature_rescoring.R +++ b/R/feature_rescoring.R @@ -74,196 +74,203 @@ computeFeatureScore <- function( all_feature_parquet, feature_cluster_parquet ) { - top_features <- arrow::read_parquet(normalizePath(all_feature_parquet)) -top_features <- top_features |> - dplyr::mutate( - shuffled = stringr::str_detect(prefix_key, "^shuffled_"), - species = prefix_key |> - stringr::str_remove("^shuffled_") |> - stringr::str_remove("(_drug_class|_drug).*") - ) + top_features <- arrow::read_parquet(normalizePath(all_feature_parquet)) + top_features <- top_features |> + dplyr::mutate( + shuffled = stringr::str_detect(prefix_key, "^shuffled_"), + species = prefix_key |> + stringr::str_remove("^shuffled_") |> + stringr::str_remove("(_drug_class|_drug).*") + ) -## -------------------------------- -## 1. Scale capability (data-driven) -## -------------------------------- + ## -------------------------------- + ## 1. Scale capability (data-driven) + ## -------------------------------- -scale_capability <- top_features |> - dplyr::distinct(feature_type, feature_subtype) |> - dplyr::group_by(feature_type) |> - dplyr::summarise( - expected_types = dplyr::n_distinct(feature_subtype), - expected_types_csv = paste(sort(feature_subtype), collapse = ","), - .groups = "drop" - ) + scale_capability <- top_features |> + dplyr::distinct(feature_type, feature_subtype) |> + dplyr::group_by(feature_type) |> + dplyr::summarise( + expected_types = dplyr::n_distinct(feature_subtype), + expected_types_csv = paste(sort(feature_subtype), collapse = ","), + .groups = "drop" + ) -## -------------------------------- -## 2. Importance → contribution → rank -## -------------------------------- + ## -------------------------------- + ## 2. Importance → contribution → rank + ## -------------------------------- -ranked_features <- top_features |> - dplyr::group_by( - species, + ranked_features <- top_features |> + dplyr::group_by( + species, drug_label, - drug_or_class, - feature_type, - feature_subtype, + drug_or_class, + feature_type, + feature_subtype, shuffled - ) |> - dplyr::mutate( - contribution = Importance / sum(Importance, na.rm = TRUE), - rank = dplyr::dense_rank(dplyr::desc(contribution)), - max_rank = max(rank), + ) |> + dplyr::mutate( + contribution = Importance / sum(Importance, na.rm = TRUE), + rank = dplyr::dense_rank(dplyr::desc(contribution)), + max_rank = max(rank), min_rank = min(rank), mean_rank = mean(rank), rank_score = 1 - (mean_rank - 1) / max_rank, - ) |> - dplyr::ungroup() |> -dplyr::select( - species, + ) |> + dplyr::ungroup() |> + dplyr::select( + species, drug_label, - drug_or_class, - feature_type, - feature_subtype, - shuffled, - Variable, - contribution, - rank, - max_rank, - min_rank, - mean_rank -) + drug_or_class, + feature_type, + feature_subtype, + shuffled, + Variable, + contribution, + rank, + max_rank, + min_rank, + mean_rank + ) -## -------------------------------- -## 3. Collapse to protein level -## -------------------------------- + ## -------------------------------- + ## 3. Collapse to protein level + ## -------------------------------- -feature_cluster <- arrow::read_parquet(normalizePath(feature_cluster_parquet)) - -ranked_features <- ranked_features |> -dplyr::mutate( + feature_cluster <- arrow::read_parquet(normalizePath(feature_cluster_parquet)) + + ranked_features <- ranked_features |> + dplyr::mutate( Variable = dplyr::case_when( - feature_type == "domains" ~ sub("_.+$", "", Variable), + feature_type == "domains" ~ sub("_.+$", "", Variable), feature_type == "proteins" ~ sub("fig.", "fig|", Variable, fixed = TRUE), TRUE ~ Variable ) ) |> -dplyr::left_join(feature_cluster, by = dplyr::join_by(Variable == feature)) + dplyr::left_join(feature_cluster, by = dplyr::join_by(Variable == feature)) -protein_scale_realization <- ranked_features |> - dplyr::filter(!is.na(cluster), shuffled == FALSE) |> - dplyr::group_by( - species, + protein_scale_realization <- ranked_features |> + dplyr::filter(!is.na(cluster), shuffled == FALSE) |> + dplyr::group_by( + species, drug_label, - drug_or_class, - cluster, - feature_type - ) |> - dplyr::summarise( - observed_types = dplyr::n_distinct(feature_subtype), - observed_types_csv = paste(sort(unique(feature_subtype)), collapse = ","), - .groups = "drop" - ) |> - dplyr::left_join(scale_capability, by = "feature_type") |> - dplyr::mutate( - scale_realization = dplyr::case_when( - observed_types == expected_types ~ "full_realization", - observed_types < expected_types ~ "partial_realization" + drug_or_class, + cluster, + feature_type + ) |> + dplyr::summarise( + observed_types = dplyr::n_distinct(feature_subtype), + observed_types_csv = paste(sort(unique(feature_subtype)), collapse = ","), + .groups = "drop" + ) |> + dplyr::left_join(scale_capability, by = "feature_type") |> + dplyr::mutate( + scale_realization = dplyr::case_when( + observed_types == expected_types ~ "full_realization", + observed_types < expected_types ~ "partial_realization" + ) ) - ) -protein_scale_summary <- protein_scale_realization |> - dplyr::group_by(species,drug_label,drug_or_class, cluster) |> - dplyr::summarise( - n_scales = dplyr::n_distinct(feature_type), - fully_realized_scales = - sum(scale_realization == "full_realization"), - partially_realized_scales = - sum(scale_realization == "partial_realization"), - scale_support_csv = - paste( - feature_type, - "(", observed_types_csv, "/", expected_types, ")", - collapse = "; " - ), - .groups = "drop" - ) |> - dplyr::mutate( - realization_score = - (fully_realized_scales + - 0.5 * partially_realized_scales) / - (fully_realized_scales + partially_realized_scales) - )|> - dplyr::mutate( - coverage_boost = - n_scales / max(n_scales) - ) |> - dplyr::mutate( - scale_factor = realization_score * coverage_boost - ) + protein_scale_summary <- protein_scale_realization |> + dplyr::group_by(species, drug_label, drug_or_class, cluster) |> + dplyr::summarise( + n_scales = dplyr::n_distinct(feature_type), + fully_realized_scales = + sum(scale_realization == "full_realization"), + partially_realized_scales = + sum(scale_realization == "partial_realization"), + scale_support_csv = + paste( + feature_type, + "(", observed_types_csv, "/", expected_types, ")", + collapse = "; " + ), + .groups = "drop" + ) |> + dplyr::mutate( + realization_score = + (fully_realized_scales + + 0.5 * partially_realized_scales) / + (fully_realized_scales + partially_realized_scales) + ) |> + dplyr::mutate( + coverage_boost = + n_scales / max(n_scales) + ) |> + dplyr::mutate( + scale_factor = realization_score * coverage_boost + ) -## -------------------------------- -## 4. Shuffle vs non-shuffle -## -------------------------------- + ## -------------------------------- + ## 4. Shuffle vs non-shuffle + ## -------------------------------- -shuffle_delta <- ranked_features |> - dplyr::filter(!is.na(cluster)) |> - dplyr::group_by( - species,drug_label, - drug_or_class, - cluster - ) |> -dplyr::summarise(mean_rank_nonshuffle = mean(mean_rank[shuffled == FALSE], na.rm = TRUE), - mean_rank_shuffle = mean(mean_rank[shuffled == TRUE], na.rm = TRUE), - delta_rank = mean_rank_shuffle - mean_rank_nonshuffle , - # If non-shuffled is missing -> NA (no evidence). If shuffled missing -> +Inf improvement. + shuffle_delta <- ranked_features |> + dplyr::filter(!is.na(cluster)) |> + dplyr::group_by( + species, drug_label, + drug_or_class, + cluster + ) |> + dplyr::summarise( + mean_rank_nonshuffle = mean(mean_rank[shuffled == FALSE], na.rm = TRUE), + mean_rank_shuffle = mean(mean_rank[shuffled == TRUE], na.rm = TRUE), + delta_rank = mean_rank_shuffle - mean_rank_nonshuffle, + # If non-shuffled is missing -> NA (no evidence). If shuffled missing -> +Inf improvement. improvement = dplyr::case_when( !is.na(mean_rank_nonshuffle) ~ tidyr::replace_na(mean_rank_shuffle, Inf) - mean_rank_nonshuffle, - TRUE ~ NA_real_ + TRUE ~ NA_real_ ), # "Good" if non-shuffled exists AND (non-shuffled < shuffled OR shuffled is missing) good_feature = !is.na(mean_rank_nonshuffle) & improvement > 0, - .groups = "drop" ) + .groups = "drop" + ) -## -------------------------------- -## 5. Cluster scoring -## -------------------------------- -robustness <- ranked_features |> -dplyr::filter(!is.na(cluster), shuffled == FALSE) |> -dplyr::group_by(species, drug_label,drug_or_class, cluster) |> -dplyr::summarize(median_contribution = median(contribution), - median_rank = median(mean_rank), .groups = "drop") |> -dplyr::left_join(protein_scale_summary |> - dplyr::distinct(species, drug_label,drug_or_class, cluster, n_scales, scale_factor), - by = c("species", "drug_label", "drug_or_class", "cluster")) |> -dplyr::left_join(shuffle_delta |> - dplyr::distinct(species, drug_label,drug_or_class, cluster, delta_rank, good_feature), - by = c("species", "drug_label","drug_or_class", "cluster"))|> - dplyr::group_by(species, drug_label,drug_or_class) |> - dplyr::mutate( - # higher = better - contrib_score = dplyr::percent_rank(median_contribution), + ## -------------------------------- + ## 5. Cluster scoring + ## -------------------------------- + robustness <- ranked_features |> + dplyr::filter(!is.na(cluster), shuffled == FALSE) |> + dplyr::group_by(species, drug_label, drug_or_class, cluster) |> + dplyr::summarize( + median_contribution = median(contribution), + median_rank = median(mean_rank), .groups = "drop" + ) |> + dplyr::left_join( + protein_scale_summary |> + dplyr::distinct(species, drug_label, drug_or_class, cluster, n_scales, scale_factor), + by = c("species", "drug_label", "drug_or_class", "cluster") + ) |> + dplyr::left_join( + shuffle_delta |> + dplyr::distinct(species, drug_label, drug_or_class, cluster, delta_rank, good_feature), + by = c("species", "drug_label", "drug_or_class", "cluster") + ) |> + dplyr::group_by(species, drug_label, drug_or_class) |> + dplyr::mutate( + # higher = better + contrib_score = dplyr::percent_rank(median_contribution), - # lower rank = better → invert - stability_score = 1 - dplyr::percent_rank(median_rank), + # lower rank = better → invert + stability_score = 1 - dplyr::percent_rank(median_rank), - # robustness: handle NaN = strongest case (missing in shuffle) - delta_score = dplyr::case_when( - is.nan(delta_rank) ~ 1, # best possible signal - delta_rank > 0 ~ dplyr::percent_rank(delta_rank), # reward - delta_rank == 0 ~ 0, # neutral - delta_rank < 0 ~ -dplyr::percent_rank(abs(delta_rank)) # penalize - ), + # robustness: handle NaN = strongest case (missing in shuffle) + delta_score = dplyr::case_when( + is.nan(delta_rank) ~ 1, # best possible signal + delta_rank > 0 ~ dplyr::percent_rank(delta_rank), # reward + delta_rank == 0 ~ 0, # neutral + delta_rank < 0 ~ -dplyr::percent_rank(abs(delta_rank)) # penalize + ), - # already bounded [0,1] - scale_score = scale_factor, - - robustness_score = - contrib_score * - stability_score * - scale_factor * - delta_score - ) + # already bounded [0,1] + scale_score = scale_factor, + robustness_score = + contrib_score * + stability_score * + scale_factor * + delta_score + ) - return(robustness) + return(robustness) }