Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ Description: Comprehensive machine learning (ML) pipeline for predicting antimic
License: BSD_3_clause + file LICENSE
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.3
Depends:
R (>= 4.5.0)
Suggests:
BiocStyle,
ComplexHeatmap,
knitr,
rmarkdown,
Expand All @@ -34,17 +34,17 @@ VignetteBuilder: knitr
Config/testthat/edition: 3
Imports:
arrow,
BiocParallel,
DBI,
dplyr,
duckdb,
future,
future.apply,
ggplot2,
ggrepel,
glmnet,
glue,
jsonlite,
hardhat,
methods,
parsnip,
purrr,
readr,
Expand All @@ -57,14 +57,20 @@ Imports:
tidyr,
tune,
vip,
withr,
workflows,
workflowsets,
yardstick
biocViews:
ML,
AMR,
MicrobialGenomics,
Pathogen,
Software,
Classification,
Regression,
StatisticalMethod,
FeatureExtraction,
MultipleComparison,
FunctionalGenomics,
Genetics,
Visualization
URL: https://github.com/JRaviLab/amRml
BugReports: https://github.com/JRaviLab/amRml/issues
Config/roxygen2/version: 8.0.0
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ importFrom(graphics,barplot)
importFrom(hardhat,tune)
importFrom(jsonlite,fromJSON)
importFrom(jsonlite,write_json)
importFrom(methods,is)
importFrom(parsnip,augment)
importFrom(parsnip,boost_tree)
importFrom(parsnip,extract_fit_engine)
Expand Down
5 changes: 3 additions & 2 deletions R/arg_check_ml.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# This script contains functions to check the arguments of ML functions.

#' @importFrom methods is
#' @importFrom tibble is_tibble
NULL

Expand Down Expand Up @@ -254,7 +255,7 @@ NULL
#' `buildLRModel()` (random forest and boosted tree support planned)
#'
.checkArgParsnipMod <- function(parsnip_mod) {
if (class(parsnip_mod)[2] != "model_spec") {
if (!is(parsnip_mod, "model_spec")) {
stop("A `parsnip` model was expected but not received.")
}
}
Expand Down Expand Up @@ -418,7 +419,7 @@ NULL
#' @param tune_res Results of grid tuning, such as the output of `tuneGrid()`
#'
.checkArgTuneRes <- function(tune_res) {
if (class(tune_res)[1] != "tune_results") {
if (!is(tune_res, "tune_results")) {
stop("The `tune_res` argument can only take `tune_results` objects.")
}
}
Expand Down
143 changes: 128 additions & 15 deletions R/core_ml.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#' @importFrom dplyr select
#' @importFrom dplyr slice
#' @importFrom hardhat tune
#' @importFrom methods is
#' @importFrom parsnip augment
#' @importFrom parsnip boost_tree
#' @importFrom parsnip extract_fit_engine
Expand Down Expand Up @@ -69,23 +70,34 @@ NULL
#' genome is resistant), but not both.
#' @param split [num] Vector of length 2 indicating the proportion of data to
#' be designated as training and validation, respectively.
#' @param seed [num] For reproducible analysis
#' @param seed [num] Optional. If supplied, the split is seeded (and the
#' caller's RNG state restored afterward) for standalone reproducibility. When
#' `NULL` (the default, as used by `runMLPipeline()`), the split inherits the
#' ambient RNG stream so it can share one seed with downstream tuning and fitting.
#' @return An `rsplit` object
#' @examples
#' ml <- tibble::tibble(
#' genome_id = paste0("g", 1:20),
#' genome_drug.resistant_phenotype = rep(c("Resistant", "Susceptible"), each = 10),
#' feat_a = rep(c(0L, 1L), 10),
#' feat_b = rep(c(1L, 0L), 10)
#' )
#' splitMLInputTibble(ml, split = c(1, 0), seed = 42)
#' @export
splitMLInputTibble <- function(ml_input_tibble, split = c(0.6, 0.2), seed = 5280) {
splitMLInputTibble <- function(ml_input_tibble, split = c(0.6, 0.2), seed = NULL) {
.checkArgTibble(ml_input_tibble, ml = TRUE)
.checkArgSplit(split)
.checkArgSeed(seed)

set.seed(seed)
if (!is.null(seed)) {
.checkArgSeed(seed)
withr::local_seed(seed)
}

target_var <- .getTargetVarName(ml_input_tibble)

# Split the data, maintaining R/S proportions.
if (split[2] == 0) {
# If in CV mode:
# Still retain a stratified testing holdout purely for final reporting metrics;
# CV is only performed on the training portion.
# CV mode: retain a stratified testing holdout purely for final reporting
# metrics; CV is only performed on the training portion.
prop_train_for_holdout <- 0.8 # 80 percent train, 20 percent reserved test
data_split <- rsample::initial_split(
ml_input_tibble,
Expand All @@ -99,7 +111,6 @@ splitMLInputTibble <- function(ml_input_tibble, split = c(0.6, 0.2), seed = 5280
strata = !!target_var
)
}

return(data_split)
}

Expand All @@ -114,6 +125,14 @@ splitMLInputTibble <- function(ml_input_tibble, split = c(0.6, 0.2), seed = 5280
#' @param pca_threshold [num] The proportion of total variance for which the
#' principle components account
#' @return A `recipe` object
#' @examples
#' train <- tibble::tibble(
#' genome_id = paste0("g", 1:10),
#' genome_drug.resistant_phenotype = rep(c("Resistant", "Susceptible"), each = 5),
#' feat_a = rep(c(0L, 1L), 5),
#' feat_b = rep(c(1L, 0L), 5)
#' )
#' buildRecipe(train, use_pca = FALSE)
#' @export
buildRecipe <- function(train_data, use_pca = FALSE, pca_threshold = 0.95) {
.checkArgTibble(train_data, ml = TRUE)
Expand Down Expand Up @@ -157,6 +176,9 @@ buildRecipe <- function(train_data, use_pca = FALSE, pca_threshold = 0.95) {
#' @param multi_class [bool] Whether to construct a model for multi-class
#' classification
#' @return A `parsnip` `logistic_reg` object
#' @examples
#' buildLRModel()
#' buildLRModel(multi_class = TRUE)
#' @export
buildLRModel <- function(multi_class = FALSE) {
.checkArgMultiClass(multi_class)
Expand Down Expand Up @@ -186,6 +208,16 @@ buildLRModel <- function(multi_class = FALSE) {
#' `buildLRModel()` (random forest and boosted tree support planned)
#' @param recipe A recipe, such as the output of `buildRecipe()`
#' @return A `workflow` object
#' @examples
#' train <- tibble::tibble(
#' genome_id = paste0("g", 1:10),
#' genome_drug.resistant_phenotype = rep(c("Resistant", "Susceptible"), each = 5),
#' feat_a = rep(c(0L, 1L), 5),
#' feat_b = rep(c(1L, 0L), 5)
#' )
#' rec <- buildRecipe(train, use_pca = FALSE)
#' lr <- buildLRModel()
#' buildWflow(lr, rec)
#' @export
buildWflow <- function(parsnip_mod, recipe) {
.checkArgParsnipMod(parsnip_mod)
Expand All @@ -210,6 +242,12 @@ buildWflow <- function(parsnip_mod, recipe) {
#' regression. 0 corresponds to L2 regularization; 1 corresponds to L1;
#' intermediate values (0, 1) correspond to elastic net.
#' @return A logistic regression tuning grid as a tibble
#' @examples
#' buildTuningGrid(
#' model = "LR",
#' penalty_vec = 10^c(-3, -1),
#' mix_vec = c(0, 0.5, 1)
#' )
#' @export
buildTuningGrid <- function(
model = "LR",
Expand Down Expand Up @@ -243,27 +281,35 @@ buildTuningGrid <- function(
#' `buildTuningGrid()`
#' @param n_fold [num] Number of folds of cross-validation
#' @return Results of grid tuning
#' @examples
#' data(demo_ml_tibble)
#' data_split <- splitMLInputTibble(demo_ml_tibble, split = c(1, 0), seed = 1)
#' wflow <- buildWflow(
#' buildLRModel(),
#' buildRecipe(rsample::training(data_split))
#' )
#' grid <- buildTuningGrid("LR", 10^c(-3, -1), c(0, 0.5, 1))
#' set.seed(1)
#' tuneGrid(wflow, data_split, grid, n_fold = 2)
#' @export
tuneGrid <- function(wflow, data_split, grid = buildTuningGrid(model = "LR"),
n_fold = 5) {
.checkArgTibble(grid)
.checkArgWflow(wflow)
.checkArgDataSplit(data_split)

split_class <- class(data_split)[1]

# Always do CV on the training portion of the split
train_df <- rsample::training(data_split)
target_var <- .getTargetVarName(train_df)

if (identical(split_class, "initial_split")) {
if (is(data_split, "initial_split")) {
# CV on training portion; final eval will use the held-out test set
resamples <- rsample::vfold_cv(train_df, v = n_fold, strata = !!target_var)
} else if (identical(split_class, "initial_validation_split")) {
} else if (is(data_split, "initial_validation_split")) {
# Use the validation partition from the original three-way split.
resamples <- rsample::validation_set(data_split)
} else {
stop("Unsupported rsample split object: ", split_class)
stop("Unsupported rsample split object: ", class(data_split)[1])
}

tune_res <- tune::tune_grid(
Expand Down Expand Up @@ -294,6 +340,19 @@ tuneGrid <- function(wflow, data_split, grid = buildTuningGrid(model = "LR"),
#' @param select_best_metric [chr] Metric to select best model: "f_meas",
#' "pr_auc", "mcc", or "bal_accuracy"
#' @return Best model workflow
#' @examples
#' data(demo_ml_tibble)
#' data_split <- splitMLInputTibble(demo_ml_tibble, split = c(1, 0), seed = 1)
#' wflow <- buildWflow(
#' buildLRModel(),
#' buildRecipe(rsample::training(data_split))
#' )
#' set.seed(1)
#' tune_res <- tuneGrid(wflow, data_split,
#' buildTuningGrid("LR", 10^c(-3, -1), c(0, 0.5, 1)),
#' n_fold = 2
#' )
#' selectBestModel(tune_res, wflow, "mcc")
#' @export
selectBestModel <- function(tune_res, wflow, select_best_metric = "mcc") {
.checkArgTuneRes(tune_res)
Expand All @@ -315,6 +374,18 @@ selectBestModel <- function(tune_res, wflow, select_best_metric = "mcc") {
#' training. This can be the output of
#' `rsample::training(splitMLInputTibble(ml_input_tibble))`.
#' @return Best model fit
#' @examples
#' data(demo_ml_tibble)
#' data_split <- splitMLInputTibble(demo_ml_tibble, split = c(1, 0), seed = 1)
#' train <- rsample::training(data_split)
#' wflow <- buildWflow(buildLRModel(), buildRecipe(train))
#' set.seed(1)
#' tune_res <- tuneGrid(wflow, data_split,
#' buildTuningGrid("LR", 10^c(-3, -1), c(0, 0.5, 1)),
#' n_fold = 2
#' )
#' best_wflow <- selectBestModel(tune_res, wflow, "mcc")
#' fitBestModel(best_wflow, train)
#' @export
fitBestModel <- function(final_mod, train_data) {
.checkArgWflow(final_mod)
Expand Down Expand Up @@ -361,6 +432,11 @@ fitBestModel <- function(final_mod, train_data) {
#' `rsample::testing(splitMLInputTibble(ml_input_tibble))`.
#' @return Test data (tibble) with an added column for predicted phenotype
#' labels
#' @examples
#' data(demo_ml_tibble)
#' data(demo_fit)
#' data_split <- splitMLInputTibble(demo_ml_tibble, split = c(1, 0), seed = 1)
#' predictML(demo_fit, rsample::testing(data_split))
#' @export
predictML <- function(fit, test_data) {
.checkArgWflow(fit)
Expand All @@ -379,6 +455,22 @@ predictML <- function(fit, test_data) {
#' @param test_data_plus_predictions Test data (tibble) with an added column for
#' predicted phenotype labels, such as the output of `predictML()`
#' @return Confusion matrix of class `conf_mat`
#' @examples
#' preds <- tibble::tibble(
#' genome_id = paste0("g", 1:8),
#' genome_drug.resistant_phenotype = factor(
#' rep(c("Resistant", "Susceptible"), each = 4),
#' levels = c("Resistant", "Susceptible")
#' ),
#' .pred_class = factor(
#' c(
#' "Resistant", "Resistant", "Susceptible", "Resistant",
#' "Susceptible", "Resistant", "Susceptible", "Susceptible"
#' ),
#' levels = c("Resistant", "Susceptible")
#' )
#' )
#' getConfusionMatrix(preds)
#' @export
getConfusionMatrix <- function(test_data_plus_predictions) {
.checkArgTestDataPlusPredictions(test_data_plus_predictions)
Expand Down Expand Up @@ -613,6 +705,24 @@ getConfusionMatrix <- function(test_data_plus_predictions) {
#'
#' @inheritParams getConfusionMatrix
#' @return F1 score, AUPRC, balanced accuracy, nMCC, and log2(AUPRC/prior)
#' @examples
#' preds <- tibble::tibble(
#' genome_id = paste0("g", 1:10),
#' genome_drug.resistant_phenotype = factor(
#' rep(c("Resistant", "Susceptible"), each = 5),
#' levels = c("Resistant", "Susceptible")
#' ),
#' .pred_class = factor(
#' c(
#' "Resistant", "Resistant", "Susceptible", "Resistant", "Susceptible",
#' "Susceptible", "Resistant", "Susceptible", "Susceptible", "Resistant"
#' ),
#' levels = c("Resistant", "Susceptible")
#' ),
#' .pred_Resistant = c(0.9, 0.8, 0.4, 0.7, 0.3, 0.2, 0.6, 0.1, 0.2, 0.55),
#' .pred_Susceptible = c(0.1, 0.2, 0.6, 0.3, 0.7, 0.8, 0.4, 0.9, 0.8, 0.45)
#' )
#' calculateEvalMets(preds)
#' @export
calculateEvalMets <- function(test_data_plus_predictions) {
.checkArgTestDataPlusPredictions(test_data_plus_predictions)
Expand Down Expand Up @@ -644,6 +754,9 @@ calculateEvalMets <- function(test_data_plus_predictions) {
#' @return A tibble with a column for top features (`Variable`), a column for
#' `Importance`, and a column for `Sign` (or, for multi-class, a tibble with
#' per-class columns of importance scores for each `Variable`)
#' @examples
#' data(demo_fit)
#' extractTopFeats(demo_fit, n_top_feats = 10)
#' @export
extractTopFeats <- function(
fit, prop_vi_top_feats = c(0, 1),
Expand Down Expand Up @@ -691,7 +804,7 @@ extractTopFeats <- function(

# Take a different approach if using multi-class (the previous code would give
# a less meaningful result).
if (class(fit$fit$actions$model$spec)[1] == "multinom_reg") {
if (is(fit$fit$actions$model$spec, "multinom_reg")) {
warning(paste(
"Extracting top features from a multi-class model.",
"The `prop_vi_top_feats` and `n_top_feats` arguments do not apply."
Expand Down
18 changes: 18 additions & 0 deletions R/data.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#' Demo ML input tibble
#'
#' Stratified subset (30 Resistant + 30 Susceptible) of the AMP-genes-binary
#' matrix from the bundled `Sfl_parquet.duckdb`, restricted to 80 feature
#' columns.
#'
#' @format A tibble with 60 rows and 82 columns: `genome_id`,
#' `genome_drug.resistant_phenotype`, and 80 binary feature columns.
#' @source `inst/scripts/make_demo_data.R`.
"demo_ml_tibble"

#' Demo LR fit
#'
#' A tuned logistic-regression workflow fitted on [demo_ml_tibble].
#'
#' @format A fitted `workflow` object (output of [fitBestModel()]).
#' @source `inst/scripts/make_demo_data.R`.
"demo_fit"
Loading