RC 1.1.2 #707

Closed · wants to merge 7 commits
4 changes: 2 additions & 2 deletions DESCRIPTION
@@ -1,6 +1,6 @@
Package: tune
Title: Tidy Tuning Tools
-Version: 1.1.1
+Version: 1.1.2
Authors@R: c(
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre"),
comment = c(ORCID = "0000-0003-2402-136X")),
@@ -31,7 +31,7 @@ Imports:
purrr (>= 1.0.0),
recipes (>= 1.0.4),
rlang (>= 1.0.2),
-rsample (>= 1.0.0),
+rsample (>= 1.1.1.9001),
tibble (>= 3.1.0),
tidyr (>= 1.2.0),
tidyselect (>= 1.1.2),
6 changes: 6 additions & 0 deletions NEWS.md
@@ -1,3 +1,9 @@
# tune 1.1.2

* `last_fit()` now works with the 3-way validation split objects from `rsample::initial_validation_split()`. `last_fit()` and `fit_best()` now have a new argument `add_validation_set` to include or exclude the validation set in the dataset used to fit the model (#701).

* Disambiguates the `verbose` and `verbose_iter` control options to better align with documented functionality. The former controls logging for general progress updates, while the latter only does so for the Bayesian search process. (#682)

# tune 1.1.1

* Fixed a bug introduced in tune 1.1.0 in `collect_()` functions where the
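The two NEWS items above can be exercised together. A minimal sketch, assuming tune >= 1.1.2 and rsample >= 1.1.1.9001 are installed; the `mtcars` data and linear model are purely illustrative:

```r
# Sketch of the new 3-way-split workflow described in the NEWS entry above.
# Assumes tune >= 1.1.2 and a recent rsample; data/model choices are illustrative.
library(tidymodels)

set.seed(1)
# Single split into training (60%), validation (20%), and test (20%) sets:
split <- rsample::initial_validation_split(mtcars, prop = c(0.6, 0.2))

spec <- parsnip::linear_reg()

# Default: fit on the training set only, then score on the test set.
res <- tune::last_fit(spec, mpg ~ ., split = split)

# Alternatively, refit on training + validation data before scoring:
res2 <- tune::last_fit(spec, mpg ~ ., split = split, add_validation_set = TRUE)
```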
35 changes: 30 additions & 5 deletions R/fit_best.R
@@ -14,6 +14,12 @@
#' If `NULL`, this argument will be set to
#' [`select_best(metric)`][tune::select_best.tune_results].
#' @param verbose A logical for printing logging.
#' @param add_validation_set When the resamples embedded in `x` are a split into
#' a training set and a validation set, should the validation set be included in
#' the data set used to train the model? If not, only the training set is used. If
#' `NULL`, the validation set is not used for resamples originating from
#' [rsample::validation_set()] while it is used for resamples originating
#' from [rsample::validation_split()].
#' @param ... Not currently used.
#' @details
#' This function is a shortcut for the manual steps of:
@@ -24,10 +30,6 @@
#' wflow_fit <- fit(wflow, data_set)
#' }
#'
#' The data used for the fit are taken from the `splits` column. If the split
#' column was from a validation split, the combined training and validation sets
#' are used.
#'
#' In comparison to [last_fit()], that function requires a finalized model, fits
#' the model on the training set defined by [rsample::initial_split()], and
#' computes metrics from the test set.
@@ -85,6 +87,7 @@ fit_best.tune_results <- function(x,
metric = NULL,
parameters = NULL,
verbose = FALSE,
add_validation_set = NULL,
...) {
if (length(list(...))) {
cli::cli_abort(c("x" = "The `...` are not used by this function."))
@@ -121,7 +124,29 @@

# ----------------------------------------------------------------------------

-dat <- x$splits[[1]]$data
+if (inherits(x$splits[[1]], "val_split")) {
+if (is.null(add_validation_set)) {
+rset_info <- attr(x, "rset_info")
+originate_from_3way_split <- rset_info$att$origin_3way %||% FALSE
+if (originate_from_3way_split) {
+add_validation_set <- FALSE
+} else {
+add_validation_set <- TRUE
+}
+}
+if (add_validation_set) {
+dat <- x$splits[[1]]$data
+} else {
+dat <- rsample::training(x$splits[[1]])
+}
+} else {
+if (!is.null(add_validation_set)) {
+rlang::warn(
+"The option `add_validation_set` is being ignored because the resampling object does not include a validation set."
+)
+}
+dat <- x$splits[[1]]$data
+}
if (verbose) {
cli::cli_inform(c("i" = "Fitting using {nrow(dat)} data points..."))
}
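The default for `add_validation_set` in the `fit_best()` changes above is resolved from the rset's attributes. A base-R sketch of that decision, standalone (no tune/rsample required); `resolve_add_validation_set` is an illustrative name, not part of the package, and `%||%` is defined locally:

```r
# Base-R sketch of how fit_best() picks a default for `add_validation_set`
# (mirrors the diff above; function name is illustrative, not tune's API).
`%||%` <- function(x, y) if (is.null(x)) y else x

resolve_add_validation_set <- function(add_validation_set, origin_3way) {
  if (!is.null(add_validation_set)) {
    return(add_validation_set)  # an explicit user choice always wins
  }
  # Resamples from rsample::initial_validation_split() carry origin_3way = TRUE:
  # their validation set was already spent on tuning, so it is excluded.
  # Legacy rsample::validation_split() objects default to including it.
  !(origin_3way %||% FALSE)
}

resolve_add_validation_set(NULL, TRUE)   # FALSE: 3-way split, exclude validation set
resolve_add_validation_set(NULL, NULL)   # TRUE: legacy validation_split()
resolve_add_validation_set(FALSE, NULL)  # FALSE: explicit user choice
```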
4 changes: 4 additions & 0 deletions R/grid_performance.R
@@ -48,6 +48,10 @@ metrics_info <- function(x) {
#' @param new_data A data frame or matrix of predictors to process.
#' @param metrics_info The output of `tune:::metrics_info(metrics)`---only
#' included as an argument to allow for pre-computing.
#' @param catalog A logical passed to `tune_log()` giving whether the message
#' is compatible with the issue cataloger. Defaults to `TRUE`. Updates that are
#' always unique and do not represent a tuning "issue" can bypass the cataloger
#' by setting `catalog = FALSE`.
#' @keywords internal
#' @name tune-internal-functions
#' @export
54 changes: 48 additions & 6 deletions R/last_fit.R
@@ -10,14 +10,20 @@
#' @param preprocessor A traditional model formula or a recipe created using
#' [recipes::recipe()].
#'
#' @param split An `rsplit` object created from [rsample::initial_split()].
#' @param split An `rsplit` object created from [rsample::initial_split()] or
#' [rsample::initial_validation_split()].
#'
#' @param metrics A [yardstick::metric_set()], or `NULL` to compute a standard
#' set of metrics.
#'
#' @param control A [control_last_fit()] object used to fine tune the last fit
#' process.
#'
#' @param add_validation_set For 3-way splits into training, validation, and test
#' set via [rsample::initial_validation_split()], should the validation set be
#' included in the data set used to train the model? If not, only the training
#' set is used.
#'
#' @param ... Currently unused.
#'
#' @details
@@ -73,7 +79,9 @@ last_fit.default <- function(object, ...) {

#' @export
#' @rdname last_fit
-last_fit.model_spec <- function(object, preprocessor, split, ..., metrics = NULL, control = control_last_fit()) {
+last_fit.model_spec <- function(object, preprocessor, split, ..., metrics = NULL,
+control = control_last_fit(),
+add_validation_set = FALSE) {
if (rlang::is_missing(preprocessor) || !is_preprocessor(preprocessor)) {
rlang::abort(paste(
"To tune a model spec, you must preprocess",
@@ -93,22 +101,27 @@ last_fit.model_spec <- function(object, preprocessor, split, ..., metrics = NULL
wflow <- add_formula(wflow, preprocessor)
}

-last_fit_workflow(wflow, split, metrics, control)
+last_fit_workflow(wflow, split, metrics, control, add_validation_set)
}


#' @rdname last_fit
#' @export
-last_fit.workflow <- function(object, split, ..., metrics = NULL, control = control_last_fit()) {
+last_fit.workflow <- function(object, split, ..., metrics = NULL,
+control = control_last_fit(),
+add_validation_set = FALSE) {
empty_ellipses(...)

control <- parsnip::condense_control(control, control_last_fit())

-last_fit_workflow(object, split, metrics, control)
+last_fit_workflow(object, split, metrics, control, add_validation_set)
}

-last_fit_workflow <- function(object, split, metrics, control) {
+last_fit_workflow <- function(object, split, metrics, control, add_validation_set) {
check_no_tuning(object)
if (inherits(split, "initial_validation_split")) {
split <- prepare_validation_split(split, add_validation_set)
}
splits <- list(split)
resamples <- rsample::manual_rset(splits, ids = "train/test split")

@@ -132,3 +145,32 @@ last_fit_workflow <- function(object, split, metrics, control) {
.stash_last_result(res)
res
}


prepare_validation_split <- function(split, add_validation_set) {
if (add_validation_set) {
# equivalent to (unexported) rsample:::rsplit() without checks
split <- structure(
list(
data = split$data,
in_id = c(split$train_id, split$val_id),
out_id = NA
),
class = "rsplit"
)
} else {
id_train_test <- seq_len(nrow(split$data))[-sort(split$val_id)]
id_train <- match(split$train_id, id_train_test)

split <- structure(
list(
data = split$data[-sort(split$val_id), , drop = FALSE],
in_id = id_train,
out_id = NA
),
class = "rsplit"
)
}

split
}
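The index arithmetic in `prepare_validation_split()` above (drop the validation rows, then re-map the training ids into positions within the shrunken data frame) can be checked with a small standalone base-R example; the row ids here are illustrative:

```r
# Standalone check of the re-mapping used when add_validation_set = FALSE:
# after dropping validation rows, training row ids must be re-expressed as
# positions in the smaller data frame.
n <- 10
train_id <- c(1, 2, 4, 7, 9)   # illustrative training row ids in the full data
val_id   <- c(3, 8)            # illustrative validation row ids

id_train_test <- seq_len(n)[-sort(val_id)]   # surviving rows: 1 2 4 5 6 7 9 10
id_train <- match(train_id, id_train_test)   # positions of training rows therein

id_train
#> [1] 1 2 3 6 7
```

Note how rows 7 and 9 shift to positions 6 and 7 once validation rows 3 and 8 are removed, which is exactly why a plain reuse of `train_id` would be wrong.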
33 changes: 20 additions & 13 deletions R/logging.R
@@ -256,12 +256,12 @@ siren <- function(x, type = "info") {
}


-tune_log <- function(control, split = NULL, task, type = "success") {
-if (!control$verbose) {
+tune_log <- function(control, split = NULL, task, type = "success", catalog = TRUE) {
+if (!any(control$verbose, control$verbose_iter)) {
return(invisible(NULL))
}

-if (uses_catalog()) {
+if (uses_catalog() & catalog) {
log_catalog(task, type)
return(NULL)
}
@@ -291,6 +291,7 @@ log_problems <- function(notes, control, split, loc, res, bad_only = FALSE) {
# Always log warnings and errors
control2 <- control
control2$verbose <- TRUE
control2$verbose_iter <- TRUE

should_catalog <- uses_catalog()

@@ -361,8 +362,8 @@ format_msg <- function(loc, msg) {

#' @export
#' @rdname tune-internal-functions
-.catch_and_log <- function(.expr, ..., bad_only = FALSE, notes) {
-tune_log(..., type = "info")
+.catch_and_log <- function(.expr, ..., bad_only = FALSE, notes, catalog = TRUE) {
+tune_log(..., type = "info", catalog = catalog)
tmp <- catcher(.expr)
new_notes <- log_problems(notes, ..., tmp, bad_only = bad_only)
assign("out_notes", new_notes, envir = parent.frame())
@@ -410,7 +411,7 @@
}

log_best <- function(control, iter, info, digits = 4) {
-if (!control$verbose) {
+if (!isTRUE(control$verbose_iter)) {
return(invisible(NULL))
}

Expand All @@ -428,24 +429,30 @@ log_best <- function(control, iter, info, digits = 4) {
info$best_iter,
")"
)
-tune_log(control, split = NULL, task = msg, type = "info")
+tune_log(control, split = NULL, task = msg, type = "info", catalog = FALSE)
}

check_and_log_flow <- function(control, results) {
if (!isTRUE(control$verbose_iter)) {
return(invisible(NULL))
}

if (all(is.na(results$.mean))) {
if (nrow(results) < 2) {
-tune_log(control, split = NULL, task = "Halting search", type = "danger")
+tune_log(control, split = NULL, task = "Halting search",
+type = "danger", catalog = FALSE)
eval.parent(parse(text = "break"))
} else {
-tune_log(control, split = NULL, task = "Skipping to next iteration", type = "danger")
+tune_log(control, split = NULL, task = "Skipping to next iteration",
+type = "danger", catalog = FALSE)
eval.parent(parse(text = "next"))
}
}
invisible(NULL)
}

log_progress <- function(control, x, maximize = TRUE, objective = NULL, digits = 4) {
-if (!control$verbose) {
+if (!isTRUE(control$verbose_iter)) {
return(invisible(NULL))
}

@@ -479,7 +486,7 @@ log_progress <- function(control, x, maximize = TRUE, objective = NULL, digits =
}

param_msg <- function(control, candidate) {
-if (!control$verbose) {
+if (!isTRUE(control$verbose_iter)) {
return(invisible(NULL))
}
candidate <- candidate[, !(names(candidate) %in% c(".mean", ".sd", "objective"))]
@@ -495,7 +502,7 @@ param_msg <- function(control, candidate) {


acq_summarizer <- function(control, iter, objective = NULL, digits = 4) {
-if (!control$verbose) {
+if (!isTRUE(control$verbose_iter)) {
return(invisible(NULL))
}
if (inherits(objective, "conf_bound") && is.function(objective$kappa)) {
@@ -509,7 +516,7 @@ acq_summarizer <- function(control, iter, objective = NULL, digits = 4) {
}
}
if (!is.null(val)) {
-tune_log(control, split = NULL, task = val, type = "info")
+tune_log(control, split = NULL, task = val, type = "info", catalog = FALSE)
}
invisible(NULL)
}
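The `verbose`/`verbose_iter` disambiguation in the logging changes above reduces to two gates. A standalone base-R sketch of that gating (the function names are illustrative, not the package's API):

```r
# Illustrative sketch of the two logging gates after this PR (not tune's API).
# General progress messages fire if either option is on; Bayesian-search
# messages (log_best, log_progress, ...) require verbose_iter specifically.
should_log_general <- function(control) {
  any(control$verbose, control$verbose_iter)
}
should_log_bayes <- function(control) {
  isTRUE(control$verbose_iter)
}

ctrl <- list(verbose = TRUE, verbose_iter = FALSE)
should_log_general(ctrl)  # TRUE: plain progress updates still print
should_log_bayes(ctrl)    # FALSE: Gaussian-process chatter stays quiet
```

This matches the documented split: `verbose` governs general progress logging, while `verbose_iter` governs only the Bayesian search process.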
8 changes: 5 additions & 3 deletions R/tune_bayes.R
@@ -342,7 +342,8 @@ tune_bayes_workflow <-
control,
NULL,
"Gaussian process model",
-notes = .notes
+notes = .notes,
+catalog = FALSE
)

gp_mod <- check_gp_failure(gp_mod, prev_gp_mod)
@@ -563,13 +564,14 @@ pred_gp <- function(object, pset, size = 5000, current = NULL, control) {
control,
split = NULL,
task = paste("Generating", nrow(pred_grid), "candidates"),
-type = "info"
+type = "info",
+catalog = FALSE
)

x <- encode_set(pred_grid, pset, as_matrix = TRUE)
gp_pred <- predict(object, x)

-tune_log(control, split = NULL, task = "Predicted candidates", type = "info")
+tune_log(control, split = NULL, task = "Predicted candidates", type = "info", catalog = FALSE)

pred_grid %>%
dplyr::mutate(.mean = gp_pred$Y_hat, .sd = sqrt(gp_pred$MSE))
Expand Down
20 changes: 15 additions & 5 deletions man/fit_best.Rd
