Skip to content

Commit

Permalink
ROC plots for response variable classes
Browse files Browse the repository at this point in the history
Implementation of the `clplots()` method
  • Loading branch information
PiotrTymoszuk authored Oct 11, 2023
1 parent c4103d4 commit ca7c2b3
Show file tree
Hide file tree
Showing 9 changed files with 503 additions and 36 deletions.
67 changes: 34 additions & 33 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,33 +1,34 @@
Package: caretExtra
Type: Package
Title: Extra functionality for predictions and quality control of caret models
Version: 1.1.2
Description: Tools for user-friendly prediction of outcome in training,
cross-validation and test data, plotting the prediction results and
calculation of error statistics.
License: GPL-3
Encoding: UTF-8
LazyData: true
Authors@R:
person("Piotr", "Tymoszuk", , "[email protected]",
role = c("aut", "cre"), comment = c(ORCID = "0000-0002-0398-6034"))
Imports:
coxed,
DescTools,
ggrepel,
plotROC,
qgam,
stringi,
survival
Depends:
caret,
clustTools,
dplyr,
generics,
ggplot2,
purrr,
rlang,
stats,
tibble
RoxygenNote: 7.2.3
Roxygen: list(markdown = TRUE)
Package: caretExtra
Type: Package
Title: Extra functionality for predictions and quality control of caret models
Version: 1.1.2
Description: Tools for user-friendly prediction of outcome in training,
cross-validation and test data, plotting the prediction results and
calculation of error statistics.
License: GPL-3
Encoding: UTF-8
LazyData: true
Authors@R:
person("Piotr", "Tymoszuk", , "[email protected]",
role = c("aut", "cre"), comment = c(ORCID = "0000-0002-0398-6034"))
Imports:
coxed,
DescTools,
ggrepel,
plotROC,
qgam,
scales,
stringi,
survival
Depends:
caret,
clustTools,
dplyr,
generics,
ggplot2,
purrr,
rlang,
stats,
tibble
RoxygenNote: 7.2.3
Roxygen: list(markdown = TRUE)
5 changes: 5 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ S3method(calibration,caretx)
S3method(calibration,predx)
S3method(classp,caretx)
S3method(classp,predx)
S3method(clplots,caretx)
S3method(clplots,predx)
S3method(clstats,caretx)
S3method(clstats,predx)
S3method(components,caretx)
Expand Down Expand Up @@ -36,6 +38,9 @@ export(caretx)
export(classp)
export(classp.caretx)
export(classp.predx)
export(clplots)
export(clplots.caretx)
export(clplots.predx)
export(clstats)
export(clstats.caretx)
export(clstats.predx)
Expand Down
56 changes: 55 additions & 1 deletion R/class_stats.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# computation of numeric metrics of the class identification quality.

# Numeric stats -------

#' Class detection quality.
#'
#' @description
Expand Down Expand Up @@ -34,7 +36,7 @@
#' receiver-operator characteristic as well as class-specific Brier scores
#' and average class assignment probabilities.
#' For `caretx` a list of such data frames, each one for the train, resample
#' and train data set.
#' and, it `newdata` is specified, the training data set.
#'
#' @export

Expand Down Expand Up @@ -79,4 +81,56 @@

}

# Class ROC plots --------

#' Receiver operating characteristic plots for outcome classes.
#'
#' @description
#' Draws receiver operating characteristic (ROC) plots for particular outcome
#' classes. The ROC statistics for particular classes are obtained by comparing
#' the given class the remaining ones (one versus rest comparison).
#'
#' @details
#' The function employs internally \code{\link[caret]{multiClassSummary}} and
#' plotting tools from the `plotROC` package. `clplots` is a S3 generic
#' function.
#'
#' @param one_plot logical: should all ROC curves be displayed in one plot?
#' @param ... extra arguments passed to \code{\link{plot_class_roc}}.
#' @inheritParams clstats
#'
#' @return a single `ggplot` object or a list of `ggplot` objects.
#' For `clplots.caretex` a list of `ggplots` with plots for the training data,
#' resamples and, if `newdata` is specified, also for the test data set.
#'
#' @export clplots

clplots <- function(x, ...) UseMethod('clplots')

#' @rdname clplots
#' @export clplots.predx
#' @export

clplots.predx <- function(x, one_plot = TRUE, ...) {

plot_class_roc(x, one_plot = one_plot, ...)

}

#' @rdname clplots
#' @export clplots.caretx
#' @export

clplots.caretx <- function(x,
newdata = NULL,
one_plot = TRUE, ...) {

stopifnot(is_caretx(x))

preds <- compact(predict(x, newdata = newdata))

map(preds, clplots, one_plot = one_plot, ...)

}

# END ------
5 changes: 4 additions & 1 deletion R/numbers.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@

## Number of observations -----

pred_data <- components(x, 'data')
.observation <- NULL

pred_data <- filter(components(x, 'data'),
!duplicated(.observation))

classes <- c('.fitted' = '.fitted',
'.outcome' = '.outcome')
Expand Down
Loading

0 comments on commit ca7c2b3

Please sign in to comment.