From 0adea0b2f5ebad72b065c799cc2a00c13ab7036f Mon Sep 17 00:00:00 2001 From: PiotrTymoszuk <80723424+PiotrTymoszuk@users.noreply.github.com> Date: Sat, 9 Sep 2023 23:23:14 +0200 Subject: [PATCH] Model performance in subsets of data Implementation of the `split()` and `augment()` methods, which can help at assessing prediction quality in subsets of the training, CV and test data defined by levels of an explanatory variable --- DESCRIPTION | 65 ++++++++++--------- NAMESPACE | 5 ++ R/caretx_oop.R | 148 +++++++++++++++++++++++++++++++++++++++++- R/extraction.R | 5 +- R/imports.R | 1 + R/utils_analysis.R | 25 +++++-- man/augment.caretx.Rd | 28 ++++++++ man/split.caretx.Rd | 35 ++++++++++ 8 files changed, 272 insertions(+), 40 deletions(-) create mode 100644 man/augment.caretx.Rd create mode 100644 man/split.caretx.Rd diff --git a/DESCRIPTION b/DESCRIPTION index 2314bad..b47bd5a 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,32 +1,33 @@ -Package: caretExtra -Type: Package -Title: Extra functionality for predictions and quality control of caret models -Version: 1.1.0 -Description: Tools for user-friendly prediction of outcome in training, - cross-validation and test data, plotting the prediction results and - calculation of error statistics. -License: GPL-3 -Encoding: UTF-8 -LazyData: true -Authors@R: - person("Piotr", "Tymoszuk", , "piotr.s.tymoszuk@gmail.com", - role = c("aut", "cre"), comment = c(ORCID = "0000-0002-0398-6034")) -Imports: - coxed, - DescTools, - ggrepel, - plotROC, - qgam, - survival -Depends: - caret, - clustTools, - dplyr, - generics, - ggplot2, - purrr, - rlang, - stats, - tibble -RoxygenNote: 7.2.3 -Roxygen: list(markdown = TRUE) +Package: caretExtra +Type: Package +Title: Extra functionality for predictions and quality control of caret models +Version: 1.1.1 +Description: Tools for user-friendly prediction of outcome in training, + cross-validation and test data, plotting the prediction results and + calculation of error statistics. +License: GPL-3 +Encoding: UTF-8 +LazyData: true +Authors@R: + person("Piotr", "Tymoszuk", , "piotr.s.tymoszuk@gmail.com", + role = c("aut", "cre"), comment = c(ORCID = "0000-0002-0398-6034")) +Imports: + coxed, + DescTools, + ggrepel, + plotROC, + qgam, + stringi, + survival +Depends: + caret, + clustTools, + dplyr, + generics, + ggplot2, + purrr, + rlang, + stats, + tibble +RoxygenNote: 7.2.3 +Roxygen: list(markdown = TRUE) diff --git a/NAMESPACE b/NAMESPACE index 207d571..5d5f6f5 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,5 +1,6 @@ # Generated by roxygen2: do not edit by hand +S3method(augment,caretx) S3method(calibration,caretx) S3method(calibration,predx) S3method(classp,caretx) @@ -22,11 +23,13 @@ S3method(predict,caretx) S3method(print,predx) S3method(residuals,caretx) S3method(residuals,predx) +S3method(split,caretx) S3method(squared,caretx) S3method(squared,predx) S3method(summary,caretx) S3method(summary,predx) export(as_caretx) +export(augment.caretx) export(calibration.caretx) export(calibration.predx) export(caretx) @@ -51,6 +54,7 @@ export(predict.caretx) export(predx) export(residuals.caretx) export(residuals.predx) +export(split.caretx) export(squared) export(squared.caretx) export(squared.predx) @@ -66,6 +70,7 @@ importFrom(dplyr,filter) importFrom(dplyr,left_join) importFrom(dplyr,mutate) importFrom(dplyr,select) +importFrom(generics,augment) importFrom(generics,components) importFrom(ggplot2,aes) importFrom(ggplot2,ggplot) diff --git a/R/caretx_oop.R b/R/caretx_oop.R index b5e9b93..86f39b8 100644 --- a/R/caretx_oop.R +++ b/R/caretx_oop.R @@ -1,4 +1,4 @@ -# Specific OOP for the `caretx` class: prediction and plotting +# Specific OOP for the `caretx` class: prediction, augmentation and subsetting # Prediction methods ----- @@ -71,3 +71,149 @@ return(preds) } + +# Augmentation: extended data frames ------ + +#' Predictions with explanatory variables. +#' +#' @description +#' The `augment()` method for \code{\link{caretx}} objects derives the +#' predictions for the training, resample (CV), and, optionally, training +#' data set appended with the explanatory variables. +#' +#' @param x a \code{\link{caretx}} object. +#' @param newdata an optional data frame with the test data. +#' @param ... extra arguments, currently none. +#' +#' @return alist of data frames. Each of them contains the +#' observation number (`.observation`), +#' resample ID (only for resample/CV, `.resample`), +#' outcome (`.outcome`), fitted values/classes (`.fitted`). +#' For classification models, class assignment probabilities are returned +#' as well (columns named after levels of the `.outcome` variable). +#' +#' @export augment.caretx +#' @export + + augment.caretx <- function(x, newdata = NULL, ...) { + + ## entry control ------ + + stopifnot(is_caretx(x)) + + .outcome <- NULL + + ## training data and predictions ------ + + outcome_var <- as.character(formula(x))[[2]] + + train_data <- select(model.frame(x), -all_of(outcome_var)) + + train_data <- mutate(train_data, + .observation = 1:nrow(train_data)) + + preds <- map(compact(predict(x, newdata = newdata)), + model.frame) + + map(preds, + left_join, + train_data, + by = '.observation') + + } + +# Splitting ------ + +#' Split predictions by an explanatory factor. +#' +#' @description +#' The `split()` method for the \code{\link{caretx}} class generates predictions +#' for the training, resample (CV) and, optionally, test data set and splits +#' them by the levels of an explanatory factor present in the training data. +#' +#' @details +#' This method may be used to investigate quality of prediction in a particular +#' subset or subsets of the data set. The method returns a plain list of +#' \code{\link{predx}} objects, whose properties can be further explored +#' with the specific \code{\link{summary.predx}} and \code{\link{plot.predx}} +#' methods. +#' +#' @return a plain list of \code{\link{predx}} objects. +#' +#' @param x a \code{\link{caretx}} object. +#' @param f a splitting factor, in quoted or unquoted form. +#' @param drop logical, should unused levels of the splitting factor +#' f be dropped? +#' @param newdata an optional data frame with the test data. +#' @param ... extra arguments, currently none. +#' +#' @export split.caretx +#' @export + + split.caretx <- function(x, + f, + drop = FALSE, + newdata = NULL, ...) { + + ## entry control ------- + + stopifnot(is_caretx(x)) + stopifnot(is.logical(drop)) + + f <- rlang::as_string(rlang::ensym(f)) + + ## augmented data -------- + + aug_data <- augment(x, newdata = newdata) + + if(!f %in% names(aug_data[[1]])) { + + stop("'f' not found in the training data.", call. = FALSE) + + } + + split_vecs <- map(aug_data, ~.x[[f]]) + + split_data <- + map2(aug_data, split_vecs, split, drop = drop) + + split_data <- unlist(split_data, recursive = FALSE) + + ## predx objects -------- + + classes <- levels(split_data[[1]][['.outcome']]) + + split_data <- map(split_data, + select, + any_of(c('.observation', + '.resample', + '.outcome', + '.fitted', + classes))) + + pred_types <- + stringi::stri_replace(names(split_data), + regex = '\\..*$', + replacement = '') + + if(is.null(levels(split_data[[1]][['.outcome']]))) { + + mod_type <- 'regression' + + } else { + + outcome_lens <- length(levels(split_data[[1]][['.outcome']])) + + if(outcome_lens == 2) mod_type <- 'binary' else mod_type <- 'multi_class' + + } + + pmap(list(data = split_data, + prediction = pred_types), + predx, + classes = classes, + type = mod_type) + + } + +# END ------ diff --git a/R/extraction.R b/R/extraction.R index 12f8aa3..18ebe86 100644 --- a/R/extraction.R +++ b/R/extraction.R @@ -89,7 +89,7 @@ formula = formula(object), resamples = object$resamples, tuning = object$results, - best_tune = object$best_tune, + best_tune = object$bestTune, prediction = predict(object, newdata = newdata, ...), probability = classp(object, newdata = newdata, ...), square_dist = squared(object, newdata = newdata, ...), @@ -463,7 +463,8 @@ outcomes_num <- as.data.frame(DescTools::Dummy(pred_data[['.outcome']], - method = 'full')) + method = 'full', + levels = x$classes)) fitted_num <- pred_data[x$classes] diff --git a/R/imports.R b/R/imports.R index 3d338b6..67b229a 100644 --- a/R/imports.R +++ b/R/imports.R @@ -1,6 +1,7 @@ # Managing imports from dependencies #' @importFrom generics components +#' @importFrom generics augment #' #' @importFrom caret calibration #' @importFrom caret predict.train diff --git a/R/utils_analysis.R b/R/utils_analysis.R index 2574e01..37f9568 100644 --- a/R/utils_analysis.R +++ b/R/utils_analysis.R @@ -84,12 +84,19 @@ correlations <- purrr::transpose(correlations)$result - correlations$pearson <- c(rho = correlations$pearson$estimate, - lwr.ci = correlations$pearson$conf.int[1], - lwr.ci = correlations$pearson$conf.int[2]) + if(!is.null(correlations$pearson)) { - correlations <- - map(correlations, function(x) if(is.null(x)) c(NA, NA, NA) else x) + correlations$pearson <- c(rho = correlations$pearson$estimate, + lwr.ci = correlations$pearson$conf.int[1], + lwr.ci = correlations$pearson$conf.int[2]) + + } else { + + correlations$pearson <- c(rho = NA, + lwr.ci = NA, + lwr.ci = NA) + + } map2_dfr(correlations, names(correlations), @@ -476,6 +483,14 @@ splits <- split(data, factor(data[['.resample']])) + ## eliminating splits with single observations + ## no stats can be computed for them + + splits <- map(splits, + function(x) if(nrow(x) > 1) x else NULL) + + splits <- compact(splits) + split_stats <- map(splits, fun, ...) est_tbl <- map2_dfc(split_stats, names(split_stats), diff --git a/man/augment.caretx.Rd b/man/augment.caretx.Rd new file mode 100644 index 0000000..8aea4bc --- /dev/null +++ b/man/augment.caretx.Rd @@ -0,0 +1,28 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/caretx_oop.R +\name{augment.caretx} +\alias{augment.caretx} +\title{Predictions with explanatory variables.} +\usage{ +\method{augment}{caretx}(x, newdata = NULL, ...) +} +\arguments{ +\item{x}{a \code{\link{caretx}} object.} + +\item{newdata}{an optional data frame with the test data.} + +\item{...}{extra arguments, currently none.} +} +\value{ +alist of data frames. Each of them contains the +observation number (\code{.observation}), +resample ID (only for resample/CV, \code{.resample}), +outcome (\code{.outcome}), fitted values/classes (\code{.fitted}). +For classification models, class assignment probabilities are returned +as well (columns named after levels of the \code{.outcome} variable). +} +\description{ +The \code{augment()} method for \code{\link{caretx}} objects derives the +predictions for the training, resample (CV), and, optionally, training +data set appended with the explanatory variables. +} diff --git a/man/split.caretx.Rd b/man/split.caretx.Rd new file mode 100644 index 0000000..1a7533a --- /dev/null +++ b/man/split.caretx.Rd @@ -0,0 +1,35 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/caretx_oop.R +\name{split.caretx} +\alias{split.caretx} +\title{Split predictions by an explanatory factor.} +\usage{ +\method{split}{caretx}(x, f, drop = FALSE, newdata = NULL, ...) +} +\arguments{ +\item{x}{a \code{\link{caretx}} object.} + +\item{f}{a splitting factor, in quoted or unquoted form.} + +\item{drop}{logical, should unused levels of the splitting factor +f be dropped?} + +\item{newdata}{an optional data frame with the test data.} + +\item{...}{extra arguments, currently none.} +} +\value{ +a plain list of \code{\link{predx}} objects. +} +\description{ +The \code{split()} method for the \code{\link{caretx}} class generates predictions +for the training, resample (CV) and, optionally, test data set and splits +them by the levels of an explanatory factor present in the training data. +} +\details{ +This method may be used to investigate quality of prediction in a particular +subset or subsets of the data set. The method returns a plain list of +\code{\link{predx}} objects, whose properties can be further explored +with the specific \code{\link{summary.predx}} and \code{\link{plot.predx}} +methods. +}