Skip to content

Commit

Permalink
GitHub examples and performance improvement
Browse files Browse the repository at this point in the history
  • Loading branch information
PiotrTymoszuk authored Oct 11, 2023
1 parent 075690f commit 3e03aab
Show file tree
Hide file tree
Showing 12 changed files with 636 additions and 43 deletions.
66 changes: 33 additions & 33 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,33 +1,33 @@
Package: caretExtra
Type: Package
Title: Extra functionality for predictions and quality control of caret models
Version: 1.1.1
Description: Tools for user-friendly prediction of outcome in training,
cross-validation and test data, plotting the prediction results and
calculation of error statistics.
License: GPL-3
Encoding: UTF-8
LazyData: true
Authors@R:
person("Piotr", "Tymoszuk", , "[email protected]",
role = c("aut", "cre"), comment = c(ORCID = "0000-0002-0398-6034"))
Imports:
coxed,
DescTools,
ggrepel,
plotROC,
qgam,
stringi,
survival
Depends:
caret,
clustTools,
dplyr,
generics,
ggplot2,
purrr,
rlang,
stats,
tibble
RoxygenNote: 7.2.3
Roxygen: list(markdown = TRUE)
Package: caretExtra
Type: Package
Title: Extra functionality for predictions and quality control of caret models
Version: 1.1.2
Description: Tools for user-friendly prediction of outcome in training,
cross-validation and test data, plotting the prediction results and
calculation of error statistics.
License: GPL-3
Encoding: UTF-8
LazyData: true
Authors@R:
person("Piotr", "Tymoszuk", , "[email protected]",
role = c("aut", "cre"), comment = c(ORCID = "0000-0002-0398-6034"))
Imports:
coxed,
DescTools,
ggrepel,
plotROC,
qgam,
stringi,
survival
Depends:
caret,
clustTools,
dplyr,
generics,
ggplot2,
purrr,
rlang,
stats,
tibble
RoxygenNote: 7.2.3
Roxygen: list(markdown = TRUE)
4 changes: 3 additions & 1 deletion R/builders.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@

caretx <- function(caret_model) {

if(is_caretx(caret_model)) return(caret_model)

## user entry control -------

if(is.null(caret_model$trainingData)) {
Expand Down Expand Up @@ -171,7 +173,7 @@

is_predx <- function(x) {

  ## Tests if `x` carries the 'predx' class anywhere in its class vector.
  ## Returns a single logical value. `inherits()` replaces the former
  ## element-wise comparison of `class(x)`, whose result was computed and
  ## then thrown away.

  inherits(x, 'predx')

}

Expand Down
4 changes: 2 additions & 2 deletions R/caretx_oop.R
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,14 @@
#'
#' @description
#' The `augment()` method for \code{\link{caretx}} objects derives the
#' predictions for the training, resample (CV), and, optionally, training
#' predictions for the training, resample (CV), and, optionally, test
#' data set appended with the explanatory variables.
#'
#' @param x a \code{\link{caretx}} object.
#' @param newdata an optional data frame with the test data.
#' @param ... extra arguments, currently none.
#'
#' @return alist of data frames. Each of them contains the
#' @return a list of data frames. Each of them contains the
#' observation number (`.observation`),
#' resample ID (only for resample/CV, `.resample`),
#' outcome (`.outcome`), fitted values/classes (`.fitted`).
Expand Down
12 changes: 11 additions & 1 deletion R/plotting.R
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@
', R\u00B2 = ', signif(unlist(stats[4, 2]),
signif_digits))

} else {
} else if(x$type == 'binary') {

plot_subtitle <-
paste0('Acc = ', signif(unlist(stats[5, 2]),
Expand All @@ -190,6 +190,16 @@
', BS = ', signif(unlist(stats[16, 2]),
signif_digits))

} else {

plot_subtitle <-
paste0('Acc = ', signif(unlist(stats[5, 2]),
signif_digits),
', \u03BA = ', signif(unlist(stats[6, 2]),
signif_digits),
', BS = ', signif(unlist(stats[15, 2]),
signif_digits))

}

}
Expand Down
2 changes: 1 addition & 1 deletion R/summary.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
#' Pseudo R squared is calculated as 1 - MSE/Var(y).
#' Pearson correlation is obtained with \code{\link[stats]{cor.test}},
#' Spearman correlation is computed with \code{\link[DescTools]{SpearmanRho}},
#' Kendall's TauB is obtained with \code{\link[DescTools]{KendallTauB}}.
#' Kendall's TauB is obtained with \code{\link[stats]{cor}}.
#' For cross-validation (CV) prediction, statistic values are calculated as
#' mean across the CV with 95\% confidence intervals (CI).
#' For multi-class predictions and models, statistics referring to
Expand Down
4 changes: 2 additions & 2 deletions R/utils_analysis.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@

resamplCor <-
try(stats::cor(data[['.outcome']], data[['.fitted']],
use = 'pairwise.complete.obs'),
use = 'pairwise.complete.obs'),
silent = TRUE)

resids <- get_resids(data)
Expand Down Expand Up @@ -77,7 +77,7 @@
correlations <-
list(pearson = function(x, y) stats::cor.test(x, y, method = 'pearson', conf.level = 0.95),
spearman = function(x, y) DescTools::SpearmanRho(x, y, conf.level = if(ci) 0.95 else NA),
kendall = function(x, y) DescTools::KendallTauB(x, y, conf.level = if(ci) 0.95 else NA))
kendall = function(x, y) c(stats::cor(x, y, method = 'kendall'), NA, NA))

correlations <- map(correlations,
~safely(.x)(data[['.outcome']], data[['.fitted']]))
Expand Down
175 changes: 175 additions & 0 deletions inst/examples/_develop.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@

library(caret)
library(caretExtra)

# Development and testing stuff -----

# Resampling scheme shared by all development models: 10-fold CV repeated
# 5 times. The training data, the resampling results and the out-of-fold
# predictions of the final tune are kept in the `train` objects, and class
# probabilities are requested.

testControl <- caret::trainControl(method = 'repeatedcv',
                                   number = 10,
                                   repeats = 5,
                                   returnData = TRUE,
                                   returnResamp = 'final',
                                   savePredictions = 'final',
                                   classProbs = TRUE)

# data for the classification model: complete cases of MASS::biopsy

devData_class <- tibble::as_tibble(MASS::biopsy)
devData_class <- dplyr::filter(devData_class, complete.cases(devData_class))

# data for the regression model: complete cases of MASS::birthwt

devData_corr <- tibble::as_tibble(MASS::birthwt)
devData_corr <- dplyr::filter(devData_corr, complete.cases(devData_corr))

# data for the multi-class model: mtcars with the cylinder number recoded
# as a factor (`.cyl`) whose levels are prefixed with 'cyl_'

devData_multi <- dplyr::mutate(mtcars,
                               .cyl = paste0('cyl_', cyl),
                               .cyl = factor(.cyl))

devData_multi <- dplyr::filter(devData_multi,
                               complete.cases(devData_multi))

# random train/test splits; the seed is fixed so that the same observations
# end up in the training and test subsets on every run of the script

set.seed(1234)

trainClassIDs <- sample(seq_len(nrow(devData_class)), 500, replace = FALSE)
trainCorrIDs <- sample(seq_len(nrow(devData_corr)), 120, replace = FALSE)
trainMultiIDs <- sample(seq_len(nrow(devData_multi)), 20, replace = FALSE)

trainClass <- devData_class[trainClassIDs, ]
testClass <- devData_class[-trainClassIDs, ]

trainCorr <- devData_corr[trainCorrIDs, ]
testCorr <- devData_corr[-trainCorrIDs, ]

trainMulti <- devData_multi[trainMultiIDs, ]
testMulti <- devData_multi[-trainMultiIDs, ]

# model formulas: outcome ~ explanatory variables

class_form <- class ~ V1 + V2 + V3 + V4 + V5 + V6 + V7 + V8 + V9

corr_form <- bwt ~ age + lwt + race + smoke + ptl + ht + ui + ftv

multi_form <- .cyl ~ mpg + disp + hp + drat + wt + qsec + gear + carb

# Development models -----

# Parallel backend for the caret training. The number of workers is derived
# from the machine: one core is left free and at most 7 workers are started
# (the formerly hard-coded value). `detectCores()` may return NA; a single
# worker is used in that case.

n_cores <- parallel::detectCores()
n_cores <- if(is.na(n_cores)) 1L else max(1L, min(7L, n_cores - 1L))

doParallel::registerDoParallel(cores = n_cores)

# classification model, tuned by Cohen's kappa

class_model <- caret::train(form = class_form,
                            data = trainClass,
                            method = 'nnet',
                            metric = 'Kappa',
                            trControl = testControl)

# regression model, tuned by mean absolute error

corr_models <- caret::train(form = corr_form,
                            data = trainCorr,
                            method = 'rf',
                            metric = 'MAE',
                            trControl = testControl)

# multi-class classification model, tuned by Cohen's kappa

multi_models <- caret::train(form = multi_form,
                             data = trainMulti,
                             method = 'nnet',
                             metric = 'Kappa',
                             trControl = testControl)

doParallel::stopImplicitCluster()

# testing the toolbox -----

## builder: wraps the three trained caret models with caretx()

caretx_class <- caretx(class_model)

caretx_corr <- caretx(corr_models)

caretx_multi <- caretx(multi_models)

## predictions, predx objects
## the results are accessed below via their `train`, `cv` and `test` elements

test_class_pred <- predict(caretx_class, newdata = testClass)

test_corr_pred <- predict(caretx_corr, newdata = testCorr)

test_multi_pred <- predict(caretx_multi, newdata = testMulti)

## prediction, caretx models
## NOTE(review): `plain = TRUE` presumably returns bare predictions instead
## of predx objects - confirm against the predict() method documentation

predict(caretx_class, newdata = testClass, plain = TRUE)

predict(caretx_corr, newdata = testCorr, plain = TRUE)

predict(caretx_multi, newdata = testMulti, plain = TRUE)

## summary, predx objects
## a different `ci_method` is exercised for each of the CV predictions

summary(test_class_pred$cv, ci_method = 'bca')

summary(test_corr_pred$cv, ci_method = 'percentile')

summary(test_multi_pred$cv, ci_method = 'norm')

## summary, caretx models

summary(caretx_class, newdata = testClass)

summary(caretx_corr, newdata = testCorr)

summary(caretx_multi, newdata = testMulti)

## extractors

nobs(caretx_corr)

## model QC
## the multi-class call is made without `newdata`

residuals(caretx_class, newdata = testClass)

residuals(caretx_corr, newdata = testCorr)

residuals(caretx_multi)

## confusion matrix
## NOTE(review): the second call passes a regression prediction - presumably
## a check of how confusion() handles non-classification input; confirm

confusion(test_class_pred$train, scale = 'fraction')

confusion(test_corr_pred$train)

confusion(test_multi_pred$cv)

confusion(caretx_class, scale = 'fraction', newdata = testClass)

confusion(caretx_multi, scale = 'none', newdata = testMulti)

## extractor

components(caretx_corr, newdata = testCorr, what = 'fit')

## plotting of the fitted values
## single predx objects first, then complete caretx models with test data

plot(x = test_corr_pred$test, type = 'regression')

plot(x = test_class_pred$test, type = 'confusion')

plot(x = test_multi_pred$test, type = 'fit')

plot(caretx_class,
     type = 'fit',
     newdata = testClass,
     plot_title = c('Training', 'CV', 'Test'))

## `theme()` is called via the ggplot2 namespace, like `theme_light()`, so
## the calls do not depend on ggplot2 being attached to the search path

plot(caretx_corr,
     type = 'fit',
     newdata = testCorr,
     plot_title = c('Training', 'CV', 'Test'),
     cust_theme = ggplot2::theme_light() +
       ggplot2::theme(plot.tag.position = 'bottom'))

plot(caretx_multi,
     type = 'fit',
     newdata = testMulti,
     plot_title = c('Training', 'CV', 'Test'),
     cust_theme = ggplot2::theme_light() +
       ggplot2::theme(plot.tag.position = 'bottom'))

## calibration
## called with one and with several values of `qu`; the second call also
## gets the test data

calibration(caretx_corr, qu = 0.5)

test_cal <- calibration(caretx_corr,
                        newdata = testCorr,
                        qu = c(0.2, 0.4, 0.6))

## plots for the training, CV and test elements of the calibration result

purrr::map(test_cal[c("train", "cv", "test")],
           plot, 'fit')

# END ------
Loading

0 comments on commit 3e03aab

Please sign in to comment.