Skip to content

Commit

Permalink
Model performance in subsets of data
Browse files Browse the repository at this point in the history
Implementation of the `split()` and `augment()` methods, which can help at assessing prediction quality in subsets of the training, CV and test data defined by levels of  an explanatory variable
  • Loading branch information
PiotrTymoszuk authored Sep 9, 2023
1 parent 994817c commit 0adea0b
Show file tree
Hide file tree
Showing 8 changed files with 272 additions and 40 deletions.
65 changes: 33 additions & 32 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,32 +1,33 @@
Package: caretExtra
Type: Package
Title: Extra functionality for predictions and quality control of caret models
Version: 1.1.0
Description: Tools for user-friendly prediction of outcome in training,
cross-validation and test data, plotting the prediction results and
calculation of error statistics.
License: GPL-3
Encoding: UTF-8
LazyData: true
Authors@R:
person("Piotr", "Tymoszuk", , "[email protected]",
role = c("aut", "cre"), comment = c(ORCID = "0000-0002-0398-6034"))
Imports:
coxed,
DescTools,
ggrepel,
plotROC,
qgam,
survival
Depends:
caret,
clustTools,
dplyr,
generics,
ggplot2,
purrr,
rlang,
stats,
tibble
RoxygenNote: 7.2.3
Roxygen: list(markdown = TRUE)
Package: caretExtra
Type: Package
Title: Extra functionality for predictions and quality control of caret models
Version: 1.1.1
Description: Tools for user-friendly prediction of outcome in training,
cross-validation and test data, plotting the prediction results and
calculation of error statistics.
License: GPL-3
Encoding: UTF-8
LazyData: true
Authors@R:
person("Piotr", "Tymoszuk", , "[email protected]",
role = c("aut", "cre"), comment = c(ORCID = "0000-0002-0398-6034"))
Imports:
coxed,
DescTools,
ggrepel,
plotROC,
qgam,
stringi,
survival
Depends:
caret,
clustTools,
dplyr,
generics,
ggplot2,
purrr,
rlang,
stats,
tibble
RoxygenNote: 7.2.3
Roxygen: list(markdown = TRUE)
5 changes: 5 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Generated by roxygen2: do not edit by hand

S3method(augment,caretx)
S3method(calibration,caretx)
S3method(calibration,predx)
S3method(classp,caretx)
Expand All @@ -22,11 +23,13 @@ S3method(predict,caretx)
S3method(print,predx)
S3method(residuals,caretx)
S3method(residuals,predx)
S3method(split,caretx)
S3method(squared,caretx)
S3method(squared,predx)
S3method(summary,caretx)
S3method(summary,predx)
export(as_caretx)
export(augment.caretx)
export(calibration.caretx)
export(calibration.predx)
export(caretx)
Expand All @@ -51,6 +54,7 @@ export(predict.caretx)
export(predx)
export(residuals.caretx)
export(residuals.predx)
export(split.caretx)
export(squared)
export(squared.caretx)
export(squared.predx)
Expand All @@ -66,6 +70,7 @@ importFrom(dplyr,filter)
importFrom(dplyr,left_join)
importFrom(dplyr,mutate)
importFrom(dplyr,select)
importFrom(generics,augment)
importFrom(generics,components)
importFrom(ggplot2,aes)
importFrom(ggplot2,ggplot)
Expand Down
148 changes: 147 additions & 1 deletion R/caretx_oop.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Specific OOP for the `caretx` class: prediction and plotting
# Specific OOP for the `caretx` class: prediction, augmentation and subsetting

# Prediction methods -----

Expand Down Expand Up @@ -71,3 +71,149 @@
return(preds)

}

# Augmentation: extended data frames ------

#' Predictions with explanatory variables.
#'
#' @description
#' The `augment()` method for \code{\link{caretx}} objects derives the
#' predictions for the training, resample (CV), and, optionally, training
#' data set appended with the explanatory variables.
#'
#' @param x a \code{\link{caretx}} object.
#' @param newdata an optional data frame with the test data.
#' @param ... extra arguments, currently none.
#'
#' @return alist of data frames. Each of them contains the
#' observation number (`.observation`),
#' resample ID (only for resample/CV, `.resample`),
#' outcome (`.outcome`), fitted values/classes (`.fitted`).
#' For classification models, class assignment probabilities are returned
#' as well (columns named after levels of the `.outcome` variable).
#'
#' @export augment.caretx
#' @export

augment.caretx <- function(x, newdata = NULL, ...) {

## entry control ------

stopifnot(is_caretx(x))

.outcome <- NULL

## training data and predictions ------

outcome_var <- as.character(formula(x))[[2]]

train_data <- select(model.frame(x), -all_of(outcome_var))

train_data <- mutate(train_data,
.observation = 1:nrow(train_data))

preds <- map(compact(predict(x, newdata = newdata)),
model.frame)

map(preds,
left_join,
train_data,
by = '.observation')

}

# Splitting ------

#' Split predictions by an explanatory factor.
#'
#' @description
#' The `split()` method for the \code{\link{caretx}} class generates predictions
#' for the training, resample (CV) and, optionally, test data set and splits
#' them by the levels of an explanatory factor present in the training data.
#'
#' @details
#' This method may be used to investigate quality of prediction in a particular
#' subset or subsets of the data set. The method returns a plain list of
#' \code{\link{predx}} objects, whose properties can be further explored
#' with the specific \code{\link{summary.predx}} and \code{\link{plot.predx}}
#' methods.
#'
#' @return a plain list of \code{\link{predx}} objects.
#'
#' @param x a \code{\link{caretx}} object.
#' @param f a splitting factor, in quoted or unquoted form.
#' @param drop logical, should unused levels of the splitting factor
#' f be dropped?
#' @param newdata an optional data frame with the test data.
#' @param ... extra arguments, currently none.
#'
#' @export split.caretx
#' @export

split.caretx <- function(x,
f,
drop = FALSE,
newdata = NULL, ...) {

## entry control -------

stopifnot(is_caretx(x))
stopifnot(is.logical(drop))

f <- rlang::as_string(rlang::ensym(f))

## augmented data --------

aug_data <- augment(x, newdata = newdata)

if(!f %in% names(aug_data[[1]])) {

stop("'f' not found in the training data.", call. = FALSE)

}

split_vecs <- map(aug_data, ~.x[[f]])

split_data <-
map2(aug_data, split_vecs, split, drop = drop)

split_data <- unlist(split_data, recursive = FALSE)

## predx objects --------

classes <- levels(split_data[[1]][['.outcome']])

split_data <- map(split_data,
select,
any_of(c('.observation',
'.resample',
'.outcome',
'.fitted',
classes)))

pred_types <-
stringi::stri_replace(names(split_data),
regex = '\\..*$',
replacement = '')

if(is.null(levels(split_data[[1]][['.outcome']]))) {

mod_type <- 'regression'

} else {

outcome_lens <- length(levels(split_data[[1]][['.outcome']]))

if(outcome_lens == 2) mod_type <- 'binary' else mod_type <- 'multi_class'

}

pmap(list(data = split_data,
prediction = pred_types),
predx,
classes = classes,
type = mod_type)

}

# END ------
5 changes: 3 additions & 2 deletions R/extraction.R
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
formula = formula(object),
resamples = object$resamples,
tuning = object$results,
best_tune = object$best_tune,
best_tune = object$bestTune,
prediction = predict(object, newdata = newdata, ...),
probability = classp(object, newdata = newdata, ...),
square_dist = squared(object, newdata = newdata, ...),
Expand Down Expand Up @@ -463,7 +463,8 @@

outcomes_num <-
as.data.frame(DescTools::Dummy(pred_data[['.outcome']],
method = 'full'))
method = 'full',
levels = x$classes))

fitted_num <- pred_data[x$classes]

Expand Down
1 change: 1 addition & 0 deletions R/imports.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Managing imports from dependencies

#' @importFrom generics components
#' @importFrom generics augment
#'
#' @importFrom caret calibration
#' @importFrom caret predict.train
Expand Down
25 changes: 20 additions & 5 deletions R/utils_analysis.R
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,19 @@

correlations <- purrr::transpose(correlations)$result

correlations$pearson <- c(rho = correlations$pearson$estimate,
lwr.ci = correlations$pearson$conf.int[1],
lwr.ci = correlations$pearson$conf.int[2])
if(!is.null(correlations$pearson)) {

correlations <-
map(correlations, function(x) if(is.null(x)) c(NA, NA, NA) else x)
correlations$pearson <- c(rho = correlations$pearson$estimate,
lwr.ci = correlations$pearson$conf.int[1],
lwr.ci = correlations$pearson$conf.int[2])

} else {

correlations$pearson <- c(rho = NA,
lwr.ci = NA,
lwr.ci = NA)

}

map2_dfr(correlations,
names(correlations),
Expand Down Expand Up @@ -476,6 +483,14 @@

splits <- split(data, factor(data[['.resample']]))

## eliminating splits with single observations
## no stats can be computed for them

splits <- map(splits,
function(x) if(nrow(x) > 1) x else NULL)

splits <- compact(splits)

split_stats <- map(splits, fun, ...)

est_tbl <- map2_dfc(split_stats, names(split_stats),
Expand Down
28 changes: 28 additions & 0 deletions man/augment.caretx.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

35 changes: 35 additions & 0 deletions man/split.caretx.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 0adea0b

Please sign in to comment.