diff --git a/DESCRIPTION b/DESCRIPTION
index 306981847..a469f1218 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -1,6 +1,6 @@
 Package: tune
 Title: Tidy Tuning Tools
-Version: 1.1.1
+Version: 1.1.2
 Authors@R: c(
     person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre"),
            comment = c(ORCID = "0000-0003-2402-136X")),
@@ -31,7 +31,7 @@ Imports:
     purrr (>= 1.0.0),
     recipes (>= 1.0.4),
     rlang (>= 1.0.2),
-    rsample (>= 1.0.0),
+    rsample (>= 1.1.1.9001),
     tibble (>= 3.1.0),
     tidyr (>= 1.2.0),
     tidyselect (>= 1.1.2),
diff --git a/NEWS.md b/NEWS.md
index c20cf1e66..67015951e 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,3 +1,9 @@
+# tune 1.1.2
+
+* `last_fit()` now works with the 3-way validation split objects from `rsample::initial_validation_split()`. `last_fit()` and `fit_best()` now have a new argument `add_validation_set` to include or exclude the validation set in the dataset used to fit the model (#701).
+
+* Disambiguates the `verbose` and `verbose_iter` control options to better align with documented functionality. The former controls logging for general progress updates, while the latter only does so for the Bayesian search process (#682).
+
 # tune 1.1.1
 
 * Fixed a bug introduced in tune 1.1.0 in `collect_()` functions where the
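
As context for the first NEWS entry, a minimal sketch of the new interface; the `ames` data and the `prop` values are illustrative, not part of the patch:

library(tune)
library(parsnip)
library(rsample)

data(ames, package = "modeldata")

# Three-way split: 60% training, 20% validation, 20% testing.
split <- initial_validation_split(ames, prop = c(0.6, 0.2))

# By default only the training portion is used for the final fit;
# add_validation_set = TRUE folds the validation rows back in.
res <- linear_reg() %>%
  last_fit(Sale_Price ~ Gr_Liv_Area + Year_Built, split, add_validation_set = TRUE)

collect_metrics(res)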
diff --git a/R/fit_best.R b/R/fit_best.R
index e1951abef..79468bc87 100644
--- a/R/fit_best.R
+++ b/R/fit_best.R
@@ -14,6 +14,12 @@
 #' If `NULL`, this argument will be set to
 #' [`select_best(metric)`][tune::select_best.tune_results].
 #' @param verbose A logical for printing logging.
+#' @param add_validation_set When the resamples embedded in `x` are a split into
+#' training set and validation set, should the validation set be included in the
+#' data set used to train the model? If not, only the training set is used. If
+#' `NULL`, the validation set is not used for resamples originating from
+#' [rsample::validation_set()] while it is used for resamples originating
+#' from [rsample::validation_split()].
 #' @param ... Not currently used.
 #' @details
 #' This function is a shortcut for the manual steps of:
@@ -24,10 +30,6 @@
 #'   wflow_fit <- fit(wflow, data_set)
 #' }
 #'
-#' The data used for the fit are taken from the `splits` column. If the split
-#' column was from a validation split, the combined training and validation sets
-#' are used.
-#'
 #' In comparison to [last_fit()], that function requires a finalized model, fits
 #' the model on the training set defined by [rsample::initial_split()], and
 #' computes metrics from the test set.
@@ -85,6 +87,7 @@ fit_best.tune_results <- function(x,
                                   metric = NULL,
                                   parameters = NULL,
                                   verbose = FALSE,
+                                  add_validation_set = NULL,
                                   ...) {
   if (length(list(...))) {
     cli::cli_abort(c("x" = "The `...` are not used by this function."))
@@ -121,7 +124,29 @@ fit_best.tune_results <- function(x,
 
   # ----------------------------------------------------------------------------
 
-  dat <- x$splits[[1]]$data
+  if (inherits(x$splits[[1]], "val_split")) {
+    if (is.null(add_validation_set)) {
+      rset_info <- attr(x, "rset_info")
+      originate_from_3way_split <- rset_info$att$origin_3way %||% FALSE
+      if (originate_from_3way_split) {
+        add_validation_set <- FALSE
+      } else {
+        add_validation_set <- TRUE
+      }
+    }
+    if (add_validation_set) {
+      dat <- x$splits[[1]]$data
+    } else {
+      dat <- rsample::training(x$splits[[1]])
+    }
+  } else {
+    if (!is.null(add_validation_set)) {
+      rlang::warn(
+        "The option `add_validation_set` is being ignored because the resampling object does not include a validation set."
+      )
+    }
+    dat <- x$splits[[1]]$data
+  }
   if (verbose) {
     cli::cli_inform(c("i" = "Fitting using {nrow(dat)} data points..."))
   }
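
A sketch of how the default resolution above plays out for `fit_best()`; `tune_res` is assumed to come from a tuning run with `save_workflow = TRUE`, as in the tests further down:

# add_validation_set = NULL resolves via the rset's origin: FALSE for
# rsample::validation_set() (from a 3-way split), TRUE for the older
# rsample::validation_split().
final_wflow <- fit_best(tune_res)

# Or choose explicitly:
final_wflow <- fit_best(tune_res, add_validation_set = TRUE)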
diff --git a/R/grid_performance.R b/R/grid_performance.R
index 8d73dfa58..a0f600e35 100644
--- a/R/grid_performance.R
+++ b/R/grid_performance.R
@@ -48,6 +48,10 @@ metrics_info <- function(x) {
 #' @param new_data A data frame or matrix of predictors to process.
 #' @param metrics_info The output of `tune:::metrics_info(metrics)`---only
 #' included as an argument to allow for pre-computing.
+#' @param catalog A logical passed to `tune_log()` giving whether the message
+#' is compatible with the issue cataloger. Defaults to `TRUE`. Updates that are
+#' always unique and do not represent a tuning "issue" can bypass the cataloger
+#' by setting `catalog = FALSE`.
 #' @keywords internal
 #' @name tune-internal-functions
 #' @export
diff --git a/R/last_fit.R b/R/last_fit.R
index 910599367..709496e32 100644
--- a/R/last_fit.R
+++ b/R/last_fit.R
@@ -10,7 +10,8 @@
 #' @param preprocessor A traditional model formula or a recipe created using
 #' [recipes::recipe()].
 #'
-#' @param split An `rsplit` object created from [rsample::initial_split()].
+#' @param split An `rsplit` object created from [rsample::initial_split()] or
+#' [rsample::initial_validation_split()].
 #'
 #' @param metrics A [yardstick::metric_set()], or `NULL` to compute a standard
 #' set of metrics.
@@ -18,6 +19,11 @@
 #' @param control A [control_last_fit()] object used to fine tune the last fit
 #' process.
 #'
+#' @param add_validation_set For 3-way splits into training, validation, and test
+#' set via [rsample::initial_validation_split()], should the validation set be
+#' included in the data set used to train the model? If not, only the training
+#' set is used.
+#'
 #' @param ... Currently unused.
 #'
 #' @details
@@ -73,7 +79,9 @@ last_fit.default <- function(object, ...) {
 
 #' @export
 #' @rdname last_fit
-last_fit.model_spec <- function(object, preprocessor, split, ..., metrics = NULL, control = control_last_fit()) {
+last_fit.model_spec <- function(object, preprocessor, split, ..., metrics = NULL,
+                                control = control_last_fit(),
+                                add_validation_set = FALSE) {
   if (rlang::is_missing(preprocessor) || !is_preprocessor(preprocessor)) {
     rlang::abort(paste(
       "To tune a model spec, you must preprocess",
@@ -93,22 +101,27 @@ last_fit.model_spec <- function(object, preprocessor, split, ..., metrics = NULL
     wflow <- add_formula(wflow, preprocessor)
   }
 
-  last_fit_workflow(wflow, split, metrics, control)
+  last_fit_workflow(wflow, split, metrics, control, add_validation_set)
 }
 
 #' @rdname last_fit
 #' @export
-last_fit.workflow <- function(object, split, ..., metrics = NULL, control = control_last_fit()) {
+last_fit.workflow <- function(object, split, ..., metrics = NULL,
+                              control = control_last_fit(),
+                              add_validation_set = FALSE) {
   empty_ellipses(...)
 
   control <- parsnip::condense_control(control, control_last_fit())
 
-  last_fit_workflow(object, split, metrics, control)
+  last_fit_workflow(object, split, metrics, control, add_validation_set)
 }
 
-last_fit_workflow <- function(object, split, metrics, control) {
+last_fit_workflow <- function(object, split, metrics, control, add_validation_set) {
   check_no_tuning(object)
+  if (inherits(split, "initial_validation_split")) {
+    split <- prepare_validation_split(split, add_validation_set)
+  }
   splits <- list(split)
   resamples <- rsample::manual_rset(splits, ids = "train/test split")
@@ -132,3 +145,32 @@ last_fit_workflow <- function(object, split, metrics, control) {
   .stash_last_result(res)
   res
 }
+
+
+prepare_validation_split <- function(split, add_validation_set) {
+  if (add_validation_set) {
+    # equivalent to (unexported) rsample:::rsplit() without checks
+    split <- structure(
+      list(
+        data = split$data,
+        in_id = c(split$train_id, split$val_id),
+        out_id = NA
+      ),
+      class = "rsplit"
+    )
+  } else {
+    id_train_test <- seq_len(nrow(split$data))[-sort(split$val_id)]
+    id_train <- match(split$train_id, id_train_test)
+
+    split <- structure(
+      list(
+        data = split$data[-sort(split$val_id), , drop = FALSE],
+        in_id = id_train,
+        out_id = NA
+      ),
+      class = "rsplit"
+    )
+  }
+
+  split
+}
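
The index remapping in `prepare_validation_split()` is the subtle part: with `add_validation_set = FALSE`, the validation rows are dropped from `data`, so `train_id` must be translated to row positions in the reduced table. A standalone sketch with toy indices (illustrative values, not package code):

# Toy 3-way split over 10 rows: rows 7-8 are the validation set.
train_id <- c(2L, 1L, 5L, 3L, 6L, 4L)
val_id   <- c(7L, 8L)

# Row numbers that survive after the validation rows are removed.
id_train_test <- seq_len(10)[-sort(val_id)]   # 1 2 3 4 5 6 9 10

# Remap the original training rows to their positions in the reduced data.
id_train <- match(train_id, id_train_test)    # 2 1 5 3 6 4

# Rows 9 and 10 (the test set) remain available as the rsplit's assessment set.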
diff --git a/R/logging.R b/R/logging.R
index f8f0f2cbd..653f70fbc 100644
--- a/R/logging.R
+++ b/R/logging.R
@@ -256,12 +256,12 @@ siren <- function(x, type = "info") {
 }
 
 
-tune_log <- function(control, split = NULL, task, type = "success") {
-  if (!control$verbose) {
+tune_log <- function(control, split = NULL, task, type = "success", catalog = TRUE) {
+  if (!any(control$verbose, control$verbose_iter)) {
     return(invisible(NULL))
   }
 
-  if (uses_catalog()) {
+  if (uses_catalog() & catalog) {
     log_catalog(task, type)
     return(NULL)
   }
@@ -291,6 +291,7 @@ log_problems <- function(notes, control, split, loc, res, bad_only = FALSE) {
   # Always log warnings and errors
   control2 <- control
   control2$verbose <- TRUE
+  control2$verbose_iter <- TRUE
 
   should_catalog <- uses_catalog()
 
@@ -361,8 +362,8 @@ format_msg <- function(loc, msg) {
 
 #' @export
 #' @rdname tune-internal-functions
-.catch_and_log <- function(.expr, ..., bad_only = FALSE, notes) {
-  tune_log(..., type = "info")
+.catch_and_log <- function(.expr, ..., bad_only = FALSE, notes, catalog = TRUE) {
+  tune_log(..., type = "info", catalog = catalog)
   tmp <- catcher(.expr)
   new_notes <- log_problems(notes, ..., tmp, bad_only = bad_only)
   assign("out_notes", new_notes, envir = parent.frame())
@@ -410,7 +411,7 @@ format_msg <- function(loc, msg) {
 }
 
 log_best <- function(control, iter, info, digits = 4) {
-  if (!control$verbose) {
+  if (!isTRUE(control$verbose_iter)) {
     return(invisible(NULL))
   }
 
@@ -428,16 +429,22 @@ log_best <- function(control, iter, info, digits = 4) {
     info$best_iter,
     ")"
   )
-  tune_log(control, split = NULL, task = msg, type = "info")
+  tune_log(control, split = NULL, task = msg, type = "info", catalog = FALSE)
 }
 
 check_and_log_flow <- function(control, results) {
+  if (!isTRUE(control$verbose_iter)) {
+    return(invisible(NULL))
+  }
+
   if (all(is.na(results$.mean))) {
     if (nrow(results) < 2) {
-      tune_log(control, split = NULL, task = "Halting search", type = "danger")
+      tune_log(control, split = NULL, task = "Halting search",
+               type = "danger", catalog = FALSE)
       eval.parent(parse(text = "break"))
     } else {
-      tune_log(control, split = NULL, task = "Skipping to next iteration", type = "danger")
+      tune_log(control, split = NULL, task = "Skipping to next iteration",
+               type = "danger", catalog = FALSE)
      eval.parent(parse(text = "next"))
     }
   }
@@ -445,7 +452,7 @@ check_and_log_flow <- function(control, results) {
 }
 
 log_progress <- function(control, x, maximize = TRUE, objective = NULL, digits = 4) {
-  if (!control$verbose) {
+  if (!isTRUE(control$verbose_iter)) {
     return(invisible(NULL))
   }
 
@@ -479,7 +486,7 @@ log_progress <- function(control, x, maximize = TRUE, objective = NULL, digits =
 }
 
 param_msg <- function(control, candidate) {
-  if (!control$verbose) {
+  if (!isTRUE(control$verbose_iter)) {
     return(invisible(NULL))
   }
   candidate <- candidate[, !(names(candidate) %in% c(".mean", ".sd", "objective"))]
@@ -495,7 +502,7 @@ param_msg <- function(control, candidate) {
 
 
 acq_summarizer <- function(control, iter, objective = NULL, digits = 4) {
-  if (!control$verbose) {
+  if (!isTRUE(control$verbose_iter)) {
     return(invisible(NULL))
   }
   if (inherits(objective, "conf_bound") && is.function(objective$kappa)) {
@@ -509,7 +516,7 @@ acq_summarizer <- function(control, iter, objective = NULL, digits = 4) {
     }
   }
   if (!is.null(val)) {
-    tune_log(control, split = NULL, task = val, type = "info")
+    tune_log(control, split = NULL, task = val, type = "info", catalog = FALSE)
   }
   invisible(NULL)
 }
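
The net effect of these logging changes, sketched at the user level: the two flags can now be toggled independently for `tune_bayes()`.

library(tune)

# Resample-level progress only (folds, preprocessor/model stages):
ctrl <- control_bayes(verbose = TRUE, verbose_iter = FALSE)

# Search-level progress only (current best, candidates, acquisition values):
ctrl <- control_bayes(verbose = FALSE, verbose_iter = TRUE)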
diff --git a/R/tune_bayes.R b/R/tune_bayes.R
index ae47a4f29..399d989a7 100644
--- a/R/tune_bayes.R
+++ b/R/tune_bayes.R
@@ -342,7 +342,8 @@ tune_bayes_workflow <-
       control,
       NULL,
       "Gaussian process model",
-      notes = .notes
+      notes = .notes,
+      catalog = FALSE
     )
 
   gp_mod <- check_gp_failure(gp_mod, prev_gp_mod)
@@ -563,13 +564,14 @@ pred_gp <- function(object, pset, size = 5000, current = NULL, control) {
     control,
     split = NULL,
     task = paste("Generating", nrow(pred_grid), "candidates"),
-    type = "info"
+    type = "info",
+    catalog = FALSE
   )
 
   x <- encode_set(pred_grid, pset, as_matrix = TRUE)
   gp_pred <- predict(object, x)
 
-  tune_log(control, split = NULL, task = "Predicted candidates", type = "info")
+  tune_log(control, split = NULL, task = "Predicted candidates", type = "info", catalog = FALSE)
 
   pred_grid %>%
     dplyr::mutate(.mean = gp_pred$Y_hat, .sd = sqrt(gp_pred$MSE))
diff --git a/man/fit_best.Rd b/man/fit_best.Rd
index 794c969bf..38a038426 100644
--- a/man/fit_best.Rd
+++ b/man/fit_best.Rd
@@ -10,7 +10,14 @@ fit_best(x, ...)
 
 \method{fit_best}{default}(x, ...)
 
-\method{fit_best}{tune_results}(x, metric = NULL, parameters = NULL, verbose = FALSE, ...)
+\method{fit_best}{tune_results}(
+  x,
+  metric = NULL,
+  parameters = NULL,
+  verbose = FALSE,
+  add_validation_set = NULL,
+  ...
+)
 }
 \arguments{
 \item{x}{The results of class \code{tune_results} (coming from functions such as
@@ -29,6 +36,13 @@ If \code{NULL}, this argument will be set to
 \code{\link[=select_best.tune_results]{select_best(metric)}}.}
 
 \item{verbose}{A logical for printing logging.}
+
+\item{add_validation_set}{When the resamples embedded in \code{x} are a split into
+training set and validation set, should the validation set be included in the
+data set used to train the model? If not, only the training set is used. If
+\code{NULL}, the validation set is not used for resamples originating from
+\code{\link[rsample:validation_set]{rsample::validation_set()}} while it is used for resamples originating
+from \code{\link[rsample:validation_split]{rsample::validation_split()}}.}
 }
 \value{
 A fitted workflow.
@@ -46,10 +60,6 @@ This function is a shortcut for the manual steps of:
   wflow_fit <- fit(wflow, data_set)
 }
 
-The data used for the fit are taken from the \code{splits} column. If the split
-column was from a validation split, the combined training and validation sets
-are used.
-
 In comparison to \code{\link[=last_fit]{last_fit()}}, that function requires a finalized model, fits
 the model on the training set defined by \code{\link[rsample:initial_split]{rsample::initial_split()}}, and
 computes metrics from the test set.
diff --git a/man/last_fit.Rd b/man/last_fit.Rd
index 2e0bed05d..aa8c1cb7e 100644
--- a/man/last_fit.Rd
+++ b/man/last_fit.Rd
@@ -14,10 +14,18 @@ last_fit(object, ...)
   split,
   ...,
   metrics = NULL,
-  control = control_last_fit()
+  control = control_last_fit(),
+  add_validation_set = FALSE
 )
 
-\method{last_fit}{workflow}(object, split, ..., metrics = NULL, control = control_last_fit())
+\method{last_fit}{workflow}(
+  object,
+  split,
+  ...,
+  metrics = NULL,
+  control = control_last_fit(),
+  add_validation_set = FALSE
+)
 }
 \arguments{
 \item{object}{A \code{parsnip} model specification or a \code{\link[workflows:workflow]{workflows::workflow()}}.
@@ -28,13 +36,19 @@ No tuning parameters are allowed.}
 \item{preprocessor}{A traditional model formula or a recipe created using
 \code{\link[recipes:recipe]{recipes::recipe()}}.}
 
-\item{split}{An \code{rsplit} object created from \code{\link[rsample:initial_split]{rsample::initial_split()}}.}
+\item{split}{An \code{rsplit} object created from \code{\link[rsample:initial_split]{rsample::initial_split()}} or
+\code{\link[rsample:initial_validation_split]{rsample::initial_validation_split()}}.}
 
 \item{metrics}{A \code{\link[yardstick:metric_set]{yardstick::metric_set()}}, or \code{NULL} to compute a standard
 set of metrics.}
 
 \item{control}{A \code{\link[=control_last_fit]{control_last_fit()}} object used to fine tune the last fit
 process.}
+
+\item{add_validation_set}{For 3-way splits into training, validation, and test
+set via \code{\link[rsample:initial_validation_split]{rsample::initial_validation_split()}}, should the validation set be
+included in the data set used to train the model? If not, only the training
+set is used.}
 }
 \value{
 A single row tibble that emulates the structure of \code{fit_resamples()}.
diff --git a/man/tune-internal-functions.Rd b/man/tune-internal-functions.Rd
index 8724d163e..e8d09365d 100644
--- a/man/tune-internal-functions.Rd
+++ b/man/tune-internal-functions.Rd
@@ -29,7 +29,7 @@ finalize_workflow_preprocessor(workflow, grid_preprocessor)
 
 initialize_catalog(control, env = rlang::caller_env())
 
-.catch_and_log(.expr, ..., bad_only = FALSE, notes)
+.catch_and_log(.expr, ..., bad_only = FALSE, notes, catalog = TRUE)
 
 .catch_and_log_fit(.expr, ..., notes)
 }
@@ -63,6 +63,11 @@ included as an argument to allow for pre-computing.}
 \item{bad_only}{A logical for whether warnings and errors should be caught.}
 
 \item{notes}{Character data to add to the logging.}
+
+\item{catalog}{A logical passed to \code{tune_log()} giving whether the message
+is compatible with the issue cataloger. Defaults to \code{TRUE}. Updates that are
+always unique and do not represent a tuning "issue" can bypass the cataloger
+by setting \code{catalog = FALSE}.}
 }
 \description{
 These are not to be meant to be invoked directly by users.
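
For anyone tracing the internals, a hedged sketch of the `catalog` escape hatch documented above; `tune_log()` is unexported, and the call below is illustrative only, with the signature taken from R/logging.R:

ctrl <- control_bayes(verbose_iter = TRUE)

# Always-unique updates would otherwise be collapsed by the issue cataloger;
# catalog = FALSE routes them straight to the logger instead.
tune:::tune_log(ctrl, split = NULL, task = "Generating 3 candidates",
                type = "info", catalog = FALSE)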
diff --git a/tests/testthat/_snaps/bayes.md b/tests/testthat/_snaps/bayes.md
index 2745f7e17..c9ce46329 100644
--- a/tests/testthat/_snaps/bayes.md
+++ b/tests/testthat/_snaps/bayes.md
@@ -8,10 +8,6 @@
       > Generating a set of 2 initial parameter results
       v Initialization complete
       
-      
-      -- Iteration 1 -----------------------------------------------------------------
-      
-      i Current best: rmse=2.418 (@iter 0)
       i Gaussian process model
       ! The Gaussian process model is being fit using 1 features but only has 2
         data points to do so. This may cause errors or a poor model fit.
       v Gaussian process model
       i Generating 3 candidates
       i Predicted candidates
-      i num_comp=2
       i Estimating performance
       i Fold01: preprocessor 1/1
       v Fold01: preprocessor 1/1
@@ -82,17 +77,11 @@
       i Fold10: preprocessor 1/1, model 1/1 (extracts)
       i Fold10: preprocessor 1/1, model 1/1 (predictions)
       v Estimating performance
-      (x) Newest results: rmse=2.666 (+/-0.281)
-      
-      -- Iteration 2 -----------------------------------------------------------------
-      
-      i Current best: rmse=2.418 (@iter 0)
       i Gaussian process model
       ! Gaussian process model: X should be in range (0, 1)
       v Gaussian process model
       i Generating 2 candidates
       i Predicted candidates
-      i num_comp=5
       i Estimating performance
       i Fold01: preprocessor 1/1
       v Fold01: preprocessor 1/1
@@ -155,7 +144,6 @@
       i Fold10: preprocessor 1/1, model 1/1 (extracts)
       i Fold10: preprocessor 1/1, model 1/1 (predictions)
       v Estimating performance
-      (x) Newest results: rmse=2.453 (+/-0.381)
     Output
     # Tuning results
     # 10-fold cross-validation
@@ -181,10 +169,34 @@
             control = control_bayes(verbose_iter = TRUE))
     Message
      Optimizing rmse using the expected improvement
+      
+      -- Iteration 1 -----------------------------------------------------------------
+      
+      i Current best: rmse=2.418 (@iter 0)
+      i Gaussian process model
       ! The Gaussian process model is being fit using 1 features but only has 2
         data points to do so. This may cause errors or a poor model fit.
       ! Gaussian process model: X should be in range (0, 1)
+      v Gaussian process model
+      i Generating 3 candidates
+      i Predicted candidates
+      i num_comp=4
+      i Estimating performance
+      v Estimating performance
+      (x) Newest results: rmse=2.461 (+/-0.37)
+      
+      -- Iteration 2 -----------------------------------------------------------------
+      
+      i Current best: rmse=2.418 (@iter 0)
+      i Gaussian process model
       ! Gaussian process model: X should be in range (0, 1)
+      v Gaussian process model
+      i Generating 2 candidates
+      i Predicted candidates
+      i num_comp=5
+      i Estimating performance
+      v Estimating performance
+      (x) Newest results: rmse=2.453 (+/-0.381)
     Output
     # Tuning results
     # 10-fold cross-validation
@@ -511,68 +523,28 @@
         Bldg_Type + Latitude + Longitude, resamples = folds, initial = 3,
         metrics = yardstick::metric_set(
         rsq), param_info = parameters(dials::cost_complexity(c(-2, 0))))
     Message
-      ! validation: internal:
-        There was 1 warning in `dplyr::summarise()`.
-        i In argument: `.estimate = metric_fn(truth = Sale_Price, estimate = ....
-        i In group 1: `cost_complexity = 0.9998726`.
-        Caused by warning:
-        ! A correlation computation is required, but `estimate` is constant an...
+      ! validation: internal: A correlation computation is required, but `estimate` is constant and ha...
       ! For the rsq estimates, 1 missing value was found and removed before fitting the Gaussian process model.
-      ! validation: internal:
-        There was 1 warning in `dplyr::summarise()`.
-        i In argument: `.estimate = metric_fn(truth = Sale_Price, estimate = ....
-        i In group 1: `cost_complexity = 0.9995734`.
-        Caused by warning:
-        ! A correlation computation is required, but `estimate` is constant an...
+      ! validation: internal: A correlation computation is required, but `estimate` is constant and ha...
       ! For the rsq estimates, 2 missing values were found and removed before fitting the Gaussian process model.
-      ! validation: internal:
-        There was 1 warning in `dplyr::summarise()`.
-        i In argument: `.estimate = metric_fn(truth = Sale_Price, estimate = ....
-        i In group 1: `cost_complexity = 0.9995`.
-        Caused by warning:
-        ! A correlation computation is required, but `estimate` is constant an...
+      ! validation: internal: A correlation computation is required, but `estimate` is constant and ha...
       ! For the rsq estimates, 3 missing values were found and removed before fitting the Gaussian process model.
-      ! validation: internal:
-        There was 1 warning in `dplyr::summarise()`.
-        i In argument: `.estimate = metric_fn(truth = Sale_Price, estimate = ....
-        i In group 1: `cost_complexity = 0.9992439`.
-        Caused by warning:
-        ! A correlation computation is required, but `estimate` is constant an...
+      ! validation: internal: A correlation computation is required, but `estimate` is constant and ha...
       ! For the rsq estimates, 4 missing values were found and removed before fitting the Gaussian process model.
-      ! validation: internal:
-        There was 1 warning in `dplyr::summarise()`.
-        i In argument: `.estimate = metric_fn(truth = Sale_Price, estimate = ....
-        i In group 1: `cost_complexity = 0.9996399`.
-        Caused by warning:
-        ! A correlation computation is required, but `estimate` is constant an...
+      ! validation: internal: A correlation computation is required, but `estimate` is constant and ha...
       ! For the rsq estimates, 5 missing values were found and removed before fitting the Gaussian process model.
-      ! validation: internal:
-        There was 1 warning in `dplyr::summarise()`.
-        i In argument: `.estimate = metric_fn(truth = Sale_Price, estimate = ....
-        i In group 1: `cost_complexity = 0.9994867`.
-        Caused by warning:
-        ! A correlation computation is required, but `estimate` is constant an...
+      ! validation: internal: A correlation computation is required, but `estimate` is constant and ha...
       ! For the rsq estimates, 6 missing values were found and removed before fitting the Gaussian process model.
-      ! validation: internal:
-        There was 1 warning in `dplyr::summarise()`.
-        i In argument: `.estimate = metric_fn(truth = Sale_Price, estimate = ....
-        i In group 1: `cost_complexity = 0.9997809`.
-        Caused by warning:
-        ! A correlation computation is required, but `estimate` is constant an...
+      ! validation: internal: A correlation computation is required, but `estimate` is constant and ha...
       ! For the rsq estimates, 7 missing values were found and removed before fitting the Gaussian process model.
-      ! validation: internal:
-        There was 1 warning in `dplyr::summarise()`.
-        i In argument: `.estimate = metric_fn(truth = Sale_Price, estimate = ....
-        i In group 1: `cost_complexity = 0.9999164`.
-        Caused by warning:
-        ! A correlation computation is required, but `estimate` is constant an...
+      ! validation: internal: A correlation computation is required, but `estimate` is constant and ha...

---
@@ -583,36 +555,11 @@
         metrics = yardstick::metric_set(rsq), param_info = parameters(dials::cost_complexity(
         c(0.5, 0))))
     Message
-      ! validation: internal:
-        There was 1 warning in `dplyr::summarise()`.
-        i In argument: `.estimate = metric_fn(truth = Sale_Price, estimate = ....
-        i In group 1: `cost_complexity = 1.95966`.
-        Caused by warning:
-        ! A correlation computation is required, but `estimate` is constant an...
-      ! validation: internal:
-        There was 1 warning in `dplyr::summarise()`.
-        i In argument: `.estimate = metric_fn(truth = Sale_Price, estimate = ....
-        i In group 1: `cost_complexity = 2.500534`.
-        Caused by warning:
-        ! A correlation computation is required, but `estimate` is constant an...
-      ! validation: internal:
-        There was 1 warning in `dplyr::summarise()`.
-        i In argument: `.estimate = metric_fn(truth = Sale_Price, estimate = ....
-        i In group 1: `cost_complexity = 1.263317`.
-        Caused by warning:
-        ! A correlation computation is required, but `estimate` is constant an...
-      ! validation: internal:
-        There was 1 warning in `dplyr::summarise()`.
-        i In argument: `.estimate = metric_fn(truth = Sale_Price, estimate = ....
-        i In group 1: `cost_complexity = 2.792327`.
-        Caused by warning:
-        ! A correlation computation is required, but `estimate` is constant an...
-      ! validation: internal:
-        There was 1 warning in `dplyr::summarise()`.
-        i In argument: `.estimate = metric_fn(truth = Sale_Price, estimate = ....
-        i In group 1: `cost_complexity = 1.224671`.
-        Caused by warning:
-        ! A correlation computation is required, but `estimate` is constant an...
+      ! validation: internal: A correlation computation is required, but `estimate` is constant and ha...
+      ! validation: internal: A correlation computation is required, but `estimate` is constant and ha...
+      ! validation: internal: A correlation computation is required, but `estimate` is constant and ha...
+      ! validation: internal: A correlation computation is required, but `estimate` is constant and ha...
+      ! validation: internal: A correlation computation is required, but `estimate` is constant and ha...
       ! All of the rsq estimates were missing. The Gaussian process model cannot be fit to the data.
       ! Gaussian process model: no non-missing arguments to min; returning Inf, no non-missing arguments...
diff --git a/tests/testthat/_snaps/resample.md b/tests/testthat/_snaps/resample.md
index 0fb804e8d..df010e0ff 100644
--- a/tests/testthat/_snaps/resample.md
+++ b/tests/testthat/_snaps/resample.md
@@ -39,10 +39,10 @@
     Message
       x Fold1: preprocessor 1/1, model 1/1:
         Error in `check_outcome()`:
-        ! For a classification model, the outcome should be a factor.
+        ! For a classification model, the outcome should be a `factor`, not a ...
       x Fold2: preprocessor 1/1, model 1/1:
         Error in `check_outcome()`:
-        ! For a classification model, the outcome should be a factor.
+        ! For a classification model, the outcome should be a `factor`, not a ...
     Condition
     Warning:
       All models failed. Run `show_notes(.Last.tune.result)` for more information.
diff --git a/tests/testthat/test-compat-dplyr.R b/tests/testthat/test-compat-dplyr.R
index 04d6b9230..f180a999e 100644
--- a/tests/testthat/test-compat-dplyr.R
+++ b/tests/testthat/test-compat-dplyr.R
@@ -291,7 +291,9 @@ test_that("left_join() can keep tune_results class if tune_results structure is
 test_that("left_join() can lose tune_results class if rows are added", {
   for (x in helper_tune_results) {
     y <- tibble(id = x$id[[1]], x = 1:2)
-    expect_s3_class_bare_tibble(left_join(x, y, by = "id", multiple = "all"))
+    expect_s3_class_bare_tibble(
+      left_join(x, y, by = "id", relationship = "many-to-many")
+    )
   }
 })
 
@@ -318,7 +320,9 @@ test_that("right_join() can keep tune_results class if tune_results structure is
 test_that("right_join() can lose tune_results class if rows are added", {
   for (x in helper_tune_results) {
     y <- tibble(id = x$id[[1]], x = 1:2)
-    expect_s3_class_bare_tibble(right_join(x, y, by = "id", multiple = "all"))
+    expect_s3_class_bare_tibble(
+      right_join(x, y, by = "id", relationship = "many-to-many")
+    )
   }
 })
 
diff --git a/tests/testthat/test-extract-helpers.R b/tests/testthat/test-extract-helpers.R
index a86f4fc34..64dcb6bb4 100644
--- a/tests/testthat/test-extract-helpers.R
+++ b/tests/testthat/test-extract-helpers.R
@@ -24,7 +24,7 @@ test_that("extract methods for resample_results objects", {
                recipes::step_normalize(recipes::all_numeric_predictors()))
   lm_rec_res <- fit_resamples(
     lm_rec_wflow,
-    resamples = rsample::vfold_cv(mtcars, V = 2),
+    resamples = rsample::vfold_cv(mtcars, v = 2),
     control = control_resamples(save_workflow = TRUE)
   )
 
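The join updates above track dplyr 1.1.0, where `multiple = "all"` was superseded by the `relationship` argument for declaring expected join cardinality. A freestanding illustration with made-up data:

library(dplyr)

x <- tibble(id = c(1, 1), a = 1:2)
y <- tibble(id = c(1, 1), b = 3:4)

# Both rows of `x` match both rows of `y`; declaring the relationship
# up front silences dplyr's many-to-many warning.
left_join(x, y, by = "id", relationship = "many-to-many")
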
diff --git a/tests/testthat/test-fit_best.R b/tests/testthat/test-fit_best.R
index 5e8995e0f..f598f6550 100644
--- a/tests/testthat/test-fit_best.R
+++ b/tests/testthat/test-fit_best.R
@@ -63,3 +63,68 @@ test_that("fit_best", {
     fit_best(ames_iter_search)
   )
 })
+
+test_that("fit_best() works with validation split: 3-way split", {
+  skip_if_not_installed("kknn")
+  skip_if_not_installed("modeldata")
+  data(ames, package = "modeldata", envir = rlang::current_env())
+
+  set.seed(23598723)
+  initial_val_split <- rsample::initial_validation_split(ames)
+  val_set <- validation_set(initial_val_split)
+
+  f <- Sale_Price ~ Gr_Liv_Area + Year_Built
+  knn_mod <- nearest_neighbor(neighbors = tune()) %>% set_mode("regression")
+  wflow <- workflow(f, knn_mod)
+
+  tune_res <- tune_grid(
+    wflow,
+    grid = tibble(neighbors = c(1, 5)),
+    resamples = val_set,
+    control = control_grid(save_workflow = TRUE)
+  ) %>% suppressWarnings()
+  set.seed(3)
+  fit_on_train <- fit_best(tune_res)
+  pred <- predict(fit_on_train, testing(initial_val_split))
+
+  set.seed(3)
+  exp_fit_on_train <- nearest_neighbor(neighbors = 5) %>%
+    set_mode("regression") %>%
+    fit(f, training(initial_val_split))
+  exp_pred <- predict(exp_fit_on_train, testing(initial_val_split))
+
+  expect_equal(pred, exp_pred)
+})
+
+test_that("fit_best() works with validation split: 2x 2-way splits", {
+  skip_if_not_installed("kknn")
+  skip_if_not_installed("modeldata")
+  data(ames, package = "modeldata", envir = rlang::current_env())
+
+  set.seed(23598723)
+  split <- rsample::initial_split(ames)
+  train_and_val <- training(split)
+  val_set <- rsample::validation_split(train_and_val)
+
+  f <- Sale_Price ~ Gr_Liv_Area + Year_Built
+  knn_mod <- nearest_neighbor(neighbors = tune()) %>% set_mode("regression")
+  wflow <- workflow(f, knn_mod)
+
+  tune_res <- tune_grid(
+    wflow,
+    grid = tibble(neighbors = c(1, 5)),
+    resamples = val_set,
+    control = control_grid(save_workflow = TRUE)
+  )
+  set.seed(3)
+  fit_on_train_and_val <- fit_best(tune_res)
+  pred <- predict(fit_on_train_and_val, testing(split))
+
+  set.seed(3)
+  exp_fit_on_train_and_val <- nearest_neighbor(neighbors = 5) %>%
+    set_mode("regression") %>%
+    fit(f, train_and_val)
+  exp_pred <- predict(exp_fit_on_train_and_val, testing(split))
+
+  expect_equal(pred, exp_pred)
+})
diff --git a/tests/testthat/test-last-fit.R b/tests/testthat/test-last-fit.R
index a974799ce..445a8d27f 100644
--- a/tests/testthat/test-last-fit.R
+++ b/tests/testthat/test-last-fit.R
@@ -146,3 +146,68 @@ test_that("`last_fit()` when objects need tuning", {
   expect_snapshot_error(last_fit(wflow_2, split))
   expect_snapshot_error(last_fit(wflow_3, split))
 })
+
+test_that("last_fit() excludes validation set for initial_validation_split objects", {
+  skip_if_not_installed("modeldata")
+  data(ames, package = "modeldata", envir = rlang::current_env())
+
+  set.seed(23598723)
+  split <- rsample::initial_validation_split(ames)
+
+  f <- Sale_Price ~ Gr_Liv_Area + Year_Built
+  lm_fit <- lm(f, data = rsample::training(split))
+  test_pred <- predict(lm_fit, rsample::testing(split))
+  rsq_test <- yardstick::rsq_vec(rsample::testing(split) %>% pull(Sale_Price), test_pred)
+
+  res <- parsnip::linear_reg() %>%
+    parsnip::set_engine("lm") %>%
+    last_fit(f, split)
+
+  expect_equal(res, .Last.tune.result)
+
+  expect_equal(
+    coef(extract_fit_engine(res$.workflow[[1]])),
+    coef(lm_fit),
+    ignore_attr = TRUE
+  )
+  expect_equal(res$.metrics[[1]]$.estimate[[2]], rsq_test)
+  expect_equal(res$.predictions[[1]]$.pred, unname(test_pred))
+  expect_true(res$.workflow[[1]]$trained)
+  expect_equal(
+    nrow(predict(res$.workflow[[1]], rsample::testing(split))),
+    nrow(rsample::testing(split))
+  )
+})
+
+test_that("last_fit() can include validation set for initial_validation_split objects", {
+  skip_if_not_installed("modeldata")
+  data(ames, package = "modeldata", envir = rlang::current_env())
+
+  set.seed(23598723)
+  split <- rsample::initial_validation_split(ames)
+
+  f <- Sale_Price ~ Gr_Liv_Area + Year_Built
+  train_val <- rbind(rsample::training(split), rsample::validation(split))
+  lm_fit <- lm(f, data = train_val)
+  test_pred <- predict(lm_fit, rsample::testing(split))
+  rsq_test <- yardstick::rsq_vec(rsample::testing(split) %>% pull(Sale_Price), test_pred)
+
+  res <- parsnip::linear_reg() %>%
+    parsnip::set_engine("lm") %>%
+    last_fit(f, split, add_validation_set = TRUE)
+
+  expect_equal(res, .Last.tune.result)
+
+  expect_equal(
+    coef(extract_fit_engine(res$.workflow[[1]])),
+    coef(lm_fit),
+    ignore_attr = TRUE
+  )
+  expect_equal(res$.metrics[[1]]$.estimate[[2]], rsq_test)
+  expect_equal(res$.predictions[[1]]$.pred, unname(test_pred))
+  expect_true(res$.workflow[[1]]$trained)
+  expect_equal(
+    nrow(predict(res$.workflow[[1]], rsample::testing(split))),
+    nrow(rsample::testing(split))
+  )
+})
diff --git a/tests/testthat/test-logging.R b/tests/testthat/test-logging.R
index 6d202548f..1ac4bbcbf 100644
--- a/tests/testthat/test-logging.R
+++ b/tests/testthat/test-logging.R
@@ -82,8 +82,8 @@ test_that("catch and log issues", {
 })
 
 test_that("logging iterations", {
-  ctrl_t <- control_grid(verbose = TRUE)
-  ctrl_f <- control_grid(verbose = FALSE)
+  ctrl_t <- control_bayes(verbose_iter = TRUE)
+  ctrl_f <- control_bayes(verbose_iter = FALSE)
   sc_1 <- list(
     best_val = 7,
     best_iter = 2,
@@ -99,7 +99,7 @@ test_that("logging iterations", {
 })
 
 test_that("logging search info", {
-  ctrl_t <- control_grid(verbose = TRUE)
+  ctrl_t <- control_bayes(verbose_iter = TRUE)
 
   tb_1 <- tibble::tibble(.mean = 1:3)
 
   expect_silent(tune:::check_and_log_flow(ctrl_t, tb_1))
@@ -114,8 +114,8 @@ test_that("logging search info", {
 })
 
 test_that("current results", {
-  ctrl_t <- control_grid(verbose = TRUE)
-  ctrl_f <- control_grid(verbose = FALSE)
+  ctrl_t <- control_bayes(verbose_iter = TRUE)
+  ctrl_f <- control_bayes(verbose_iter = FALSE)
 
   tb_2 <- tibble::tibble(
     .metric = rep(letters[1:2], each = 4),
@@ -141,8 +141,8 @@ test_that("current results", {
 
 
 test_that("show parameters", {
-  ctrl_t <- control_grid(verbose = TRUE)
-  ctrl_f <- control_grid(verbose = FALSE)
+  ctrl_t <- control_bayes(verbose_iter = TRUE)
+  ctrl_f <- control_bayes(verbose_iter = FALSE)
 
   expect_snapshot(tune:::param_msg(ctrl_t, iris[1, 4:5]))
   expect_silent(tune:::param_msg(ctrl_f, iris[1, 4:5]))
@@ -150,8 +150,8 @@ test_that("show parameters", {
 
 
 test_that("acquisition functions", {
-  ctrl_t <- control_grid(verbose = TRUE)
-  ctrl_f <- control_grid(verbose = FALSE)
+  ctrl_t <- control_bayes(verbose_iter = TRUE)
+  ctrl_f <- control_bayes(verbose_iter = FALSE)
 
   expect_silent(tune:::acq_summarizer(ctrl_t, 1))
   expect_silent(tune:::acq_summarizer(ctrl_t, 1, conf_bound()))
diff --git a/tests/testthat/test-resample.R b/tests/testthat/test-resample.R
index 54f2820b6..1d4c8827f 100644
--- a/tests/testthat/test-resample.R
+++ b/tests/testthat/test-resample.R
@@ -221,13 +221,12 @@ test_that("classification models generate correct error message", {
   expect_length(notes, 2L)
 
   # Known failure in the recipe
-  expect_true(all(grepl("outcome should be a factor", note)))
+  expect_true(all(grepl("outcome should be a `factor`", note)))
 
   expect_equal(extract, list(NULL, NULL))
   expect_equal(predictions, list(NULL, NULL))
 })
 
-
 # tune_grid() fallback ---------------------------------------------------------
 
 test_that("`tune_grid()` falls back to `fit_resamples()` - formula", {
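
For reference, a minimal sketch of the rsample pieces the new tests combine (accessors available as of the rsample version required in the DESCRIPTION bump above):

library(rsample)

data(ames, package = "modeldata")

set.seed(1)
split <- initial_validation_split(ames)

nrow(training(split))    # rows used by last_fit(..., add_validation_set = FALSE)
nrow(validation(split))  # rows held out for tuning
nrow(testing(split))     # final assessment rows

val_set <- validation_set(split)  # single-resample rset for tune_grid()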