Commit bfdd585: enable tuning postprocessors
simonpcouch committed Nov 15, 2024 (1 parent: fec77ef)

Showing 5 changed files with 247 additions and 33 deletions.
81 changes: 56 additions & 25 deletions R/grid_code_paths.R
@@ -375,10 +375,12 @@ tune_grid_loop_iter <- function(split,

model_params <- vctrs::vec_slice(params, params$source == "model_spec")
preprocessor_params <- vctrs::vec_slice(params, params$source == "recipe")
postprocessor_params <- vctrs::vec_slice(params, params$source == "tailor")

param_names <- params$id
model_param_names <- model_params$id
preprocessor_param_names <- preprocessor_params$id
postprocessor_param_names <- postprocessor_params$id

# inline rsample::assessment so that we can pass indices to `predict_model()`
assessment_rows <- as.integer(split, data = "assessment")
@@ -542,34 +544,62 @@ tune_grid_loop_iter <- function(split,
# if the postprocessor does not require training, then `calibration` will
# be NULL and nothing other than the column names is learned from
# `assessment`.
workflow_with_post <- .fit_post(workflow, calibration %||% assessment)

workflow_with_post <- .fit_finalize(workflow_with_post)
# --------------------------------------------------------------------------
# Postprocessor loop
iter_postprocessors <- iter_grid_info_model[[".iter_postprocessor"]]

# run extract function on workflow with trained postprocessor
elt_extract <- .catch_and_log(
extract_details(workflow_with_post, control$extract),
control,
split_labels,
paste(iter_msg_model, "(extracts)"),
bad_only = TRUE,
notes = out_notes
)
elt_extract <- make_extracts(elt_extract, iter_grid, split_labels, .config = iter_config)
out_extracts <- append_extracts(out_extracts, elt_extract)
workflow_pre_and_model <- workflow

# generate predictions on the assessment set from the model and apply the
# post-processor to those predictions to generate updated predictions
iter_predictions <- .catch_and_log(
predict_model(assessment, assessment_rows, workflow_with_post, iter_grid,
metrics, iter_submodels, metrics_info = metrics_info,
eval_time = eval_time),
control,
split_labels,
paste(iter_msg_model, "(predictions with post-processor)"),
bad_only = TRUE,
notes = out_notes
)
for (iter_postprocessor in iter_postprocessors) {
  workflow <- workflow_pre_and_model

  iter_grid_info_postprocessor <- vctrs::vec_slice(
    iter_grid_info_model,
    iter_grid_info_model$.iter_postprocessor == iter_postprocessor
  )

  iter_grid_postprocessor <- iter_grid_info_postprocessor[, postprocessor_param_names]

  iter_msg_postprocessor <- iter_grid_info_postprocessor[[".msg_postprocessor"]]
  iter_config <- iter_grid_info_postprocessor[[".iter_config_post"]][[1L]]

  workflow <- finalize_workflow_postprocessor(workflow, iter_grid_postprocessor)

  workflow_with_post <- .fit_post(workflow, calibration %||% assessment)

  workflow_with_post <- .fit_finalize(workflow_with_post)

  iter_grid <- dplyr::bind_cols(
    iter_grid_preprocessor,
    iter_grid_model,
    iter_grid_postprocessor
  )

  # run extract function on workflow with trained postprocessor
  elt_extract <- .catch_and_log(
    extract_details(workflow_with_post, control$extract),
    control,
    split_labels,
    paste(iter_msg_model, "(extracts)"),
    bad_only = TRUE,
    notes = out_notes
  )
  elt_extract <- make_extracts(elt_extract, iter_grid, split_labels, .config = iter_config)
  out_extracts <- append_extracts(out_extracts, elt_extract)

  # generate predictions on the assessment set from the model and apply the
  # post-processor to those predictions to generate updated predictions
  iter_predictions <- .catch_and_log(
    predict_model(assessment, assessment_rows, workflow_with_post, iter_grid,
                  metrics, iter_submodels, metrics_info = metrics_info,
                  eval_time = eval_time),
    control,
    split_labels,
    paste(iter_msg_postprocessor, "(predictions with post-processor)"),
    bad_only = TRUE,
    notes = out_notes
  )

# now, assess those predictions with performance metrics
}
Expand All @@ -595,6 +625,7 @@ tune_grid_loop_iter <- function(split,
control = control,
.config = iter_config_metrics
)
} # postprocessor loop
} # model loop
} # preprocessor loop

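For orientation: with this change, tune_grid_loop_iter() nests a postprocessor loop inside the existing preprocessor and model loops, so only the postprocessor is refit per candidate. A structural sketch of the flow above (identifiers abbreviated; a schematic, not literal source):

    for (iter_preprocessor in iter_preprocessors) {
      workflow <- finalize_workflow_preprocessor(workflow, iter_grid_preprocessor)
      for (iter_model in iter_models) {
        # the preprocessor and model are fitted once, then cached ...
        workflow_pre_and_model <- workflow
        for (iter_postprocessor in iter_postprocessors) {
          # ... and only the postprocessor is finalized and fitted per candidate
          workflow <- finalize_workflow_postprocessor(workflow_pre_and_model, iter_grid_postprocessor)
          workflow_with_post <- .fit_post(workflow, calibration %||% assessment)
          workflow_with_post <- .fit_finalize(workflow_with_post)
          # extract, predict, and append metrics for this candidate
        }
      }
    }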
97 changes: 90 additions & 7 deletions R/grid_helpers.R
@@ -1,7 +1,6 @@

predict_model <- function(new_data, orig_rows, workflow, grid, metrics,
submodels = NULL, metrics_info, eval_time = NULL) {

model <- extract_fit_parsnip(workflow)

forged <- forge_from_workflow(new_data, workflow)
@@ -260,6 +259,22 @@ finalize_workflow_preprocessor <- function(workflow, grid_preprocessor) {
workflow
}

#' @export
#' @rdname tune-internal-functions
finalize_workflow_postprocessor <- function(workflow, grid_postprocessor) {
  # Already finalized, nothing to tune
  if (ncol(grid_postprocessor) == 0L) {
    return(workflow)
  }

  postprocessor <- workflows::extract_postprocessor(workflow)
  postprocessor <- merge(postprocessor, grid_postprocessor)$x[[1]]

  workflow <- set_workflow_tailor(workflow, postprocessor)

  workflow
}
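A usage sketch for the new helper, assuming a workflow `wf` whose tailor has a single tuned adjustment parameter (the id `threshold` here is hypothetical):

    grid_post <- tibble::tibble(threshold = 0.4)  # one-row postprocessor grid
    wf <- finalize_workflow_postprocessor(wf, grid_post)
    # the tailor inside `wf` now has threshold = 0.4 in place of tune()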

# ------------------------------------------------------------------------------

# For any type of tuning, and for fit-resamples, we generate a unified
Expand Down Expand Up @@ -310,16 +325,20 @@ compute_grid_info <- function(workflow, grid) {
grid <- tibble::as_tibble(grid)

parameters <- hardhat::extract_parameter_set_dials(workflow)
parameters_model <- dplyr::filter(parameters, source == "model_spec")

parameters_preprocessor <- dplyr::filter(parameters, source == "recipe")
parameters_model <- dplyr::filter(parameters, source == "model_spec")
parameters_postprocessor <- dplyr::filter(parameters, source == "tailor")

any_parameters_model <- nrow(parameters_model) > 0
any_parameters_preprocessor <- nrow(parameters_preprocessor) > 0

res <- min_grid(extract_spec_parsnip(workflow), grid)
any_parameters_model <- nrow(parameters_model) > 0
any_parameters_postprocessor <- nrow(parameters_postprocessor) > 0

syms_pre <- rlang::syms(parameters_preprocessor$id)
syms_mod <- rlang::syms(parameters_model$id)
syms_post <- rlang::syms(parameters_postprocessor$id)

res <- min_grid(extract_spec_parsnip(workflow), grid)

# ----------------------------------------------------------------------------
# Create an order of execution to train the preprocessor (if any). This will
@@ -340,7 +359,7 @@
res$.lab_pre <- "Preprocessor1"
}

# Make the label shown in the grid and in loggining
# Make the label shown in the grid and in logging
res$.msg_preprocessor <-
new_msgs_preprocessor(
res$.iter_preprocessor,
@@ -351,6 +370,17 @@
# Now make a similar iterator across models. Conditioning on each unique
# preprocessing candidate set, make an iterator for the model candidate sets
# (if any)
if (any_parameters_postprocessor) {
# Ensure that the submodel trick kicks in by temporarily nesting the
# postprocessor parameters while iterating over the model grid.
# TODO: will this introduce issues when there are matching postprocessor
# values across models?
# ... I think we actually want to (temporarily?) situate these as submodels
res <- tidyr::nest(
res,
.data_post = all_of(parameters_postprocessor$id)
)
}

res <-
res %>%
@@ -370,9 +400,28 @@
n = res$.num_models,
res$.msg_preprocessor)

res %>%
res <- res %>%
dplyr::select(-.num_models) %>%
dplyr::relocate(dplyr::starts_with(".msg"))

# ----------------------------------------------------------------------------
# Finally, iterate across postprocessors. Conditioning on an .iter_config,
# make an iterator for each postprocessing candidate set (if any).
if (!any_parameters_postprocessor) {
return(res)
}

res <-
res %>%
dplyr::group_nest(.iter_config, keep = TRUE) %>%
dplyr::mutate(
data = purrr::map(data, make_iter_postprocessor)
) %>%
tidyr::unnest(cols = data) %>%
dplyr::relocate(dplyr::starts_with(".iter"), dplyr::starts_with(".msg")) %>%
tidyr::unnest(.data_post)

res
}
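For a grid with tuned postprocessor parameters, the returned grid_info thus gains `.iter_postprocessor`, `.msg_postprocessor`, and `.iter_config_post` columns alongside the existing preprocessor and model iterators. Illustrative rows (shapes only, not real output):

    #> .iter_postprocessor  .msg_postprocessor                               .iter_config_post
    #> 1                    preprocessor 1/1, model 1/2, postprocessor 1/2  Preprocessor1_Model1_Postprocessor1
    #> 2                    preprocessor 1/1, model 1/2, postprocessor 2/2  Preprocessor1_Model1_Postprocessor2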

make_iter_config <- function(dat) {
@@ -385,6 +434,32 @@ make_iter_config <- function(dat) {
tibble::tibble(.iter_config = .iter_config)
}

make_iter_postprocessor <- function(data) {
  data %>%
    mutate(
      .iter_postprocessor = seq_len(nrow(data)),
      .msg_postprocessor = new_msgs_postprocessor(
        i = .iter_postprocessor,
        n = max(.iter_postprocessor),
        msgs_model = .msg_model
      ),
      .iter_config_post = purrr::map2(
        .iter_config,
        .iter_postprocessor,
        make_iter_config_post
      )
    ) %>%
    select(-.iter_config)
}

make_iter_config_post <- function(iter_config, iter_postprocessor) {
  paste0(
    iter_config,
    "_Postprocessor",
    iter_postprocessor
  )
}
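Derived directly from the paste0() above:

    make_iter_config_post("Preprocessor1_Model2", 3)
    #> [1] "Preprocessor1_Model2_Postprocessor3"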

# This generates a "dummy" grid_info object that has the same
# structure as a grid-info object with no tunable recipe parameters
# and no tunable model parameters.
Expand Down Expand Up @@ -420,6 +495,9 @@ new_msgs_preprocessor <- function(i, n) {
new_msgs_model <- function(i, n, msgs_preprocessor) {
paste0(msgs_preprocessor, ", model ", i, "/", n)
}
new_msgs_postprocessor <- function(i, n, msgs_model) {
  paste0(msgs_model, ", postprocessor ", i, "/", n)
}
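The three message helpers compose, so a full postprocessor message reads, for example:

    new_msgs_postprocessor(1, 2, msgs_model = "preprocessor 1/1, model 4/10")
    #> [1] "preprocessor 1/1, model 4/10, postprocessor 1/2"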

# c(1, 10) -> c("01", "10")
format_with_padding <- function(x) {
Expand Down Expand Up @@ -467,3 +545,8 @@ set_workflow_recipe <- function(workflow, recipe) {
  workflow$pre$actions$recipe$recipe <- recipe
  workflow
}

set_workflow_tailor <- function(workflow, tailor) {
  workflow$post$actions$tailor$tailor <- tailor
  workflow
}
21 changes: 20 additions & 1 deletion R/merge.R
@@ -75,6 +75,12 @@ merge.model_spec <- function(x, y, ...) {
merger(x, y, ...)
}

#' @export
#' @rdname merge.recipe
merge.tailor <- function(x, y, ...) {
  merger(x, y, ...)
}
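Like the existing merge.recipe() and merge.model_spec() methods, this returns a tibble whose list-column `x` holds one updated object per grid row, which is why finalize_workflow_postprocessor() in R/grid_helpers.R takes merge(...)$x[[1]]. A sketch (`tlr` is a hypothetical tailor with a tuned `threshold`):

    merged <- merge(tlr, tibble::tibble(threshold = c(0.3, 0.5)))
    length(merged$x)  # 2
    merged$x[[1]]     # tailor with threshold = 0.3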

update_model <- function(grid, object, pset, step_id, nms, ...) {
for (i in nms) {
param_info <- pset %>% dplyr::filter(id == i & source == "model_spec")
@@ -108,6 +114,16 @@ update_recipe <- function(grid, object, pset, step_id, nms, ...) {
object
}

update_tailor <- function(grid, object, pset, adjustment_id, nms, ...) {
  for (i in nms) {
    param_info <- pset %>% dplyr::filter(id == i & source == "tailor")
    if (nrow(param_info) == 1) {
      idx <- which(adjustment_id == param_info$component_id)
      object$adjustments[[idx]][["arguments"]][[param_info$name]] <- grid[[i]]
    }
  }
  object
}

merger <- function(x, y, ...) {
if (!is.data.frame(y)) {
Expand All @@ -127,9 +143,12 @@ merger <- function(x, y, ...) {
if (inherits(x, "recipe")) {
updater <- update_recipe
step_ids <- purrr::map_chr(x$steps, ~ .x$id)
} else {
} else if (inherits(x, "model_spec")) {
updater <- update_model
step_ids <- NULL
} else {
updater <- update_tailor
step_ids <- purrr::map_chr(x$adjustments, ~class(.x)[1])
}

if (!any(grid_name %in% pset$id)) {
10 changes: 10 additions & 0 deletions R/min_grid.R
@@ -326,3 +326,13 @@ min_grid.pls <- fit_max_value
#' @export min_grid.poisson_reg
#' @rdname min_grid
min_grid.poisson_reg <- fit_max_value


# When `min_grid()` is applied to grids with additional columns for
# postprocessors, we need to nest the postprocessor columns into
# .submodels to effectively enable the submodel trick.
# See: https://gist.github.com/simonpcouch/28d984cdcc3fc6d22ff776ed8740004e
nest_min_grid <- function(min_grid, post_params) {
  # TODO
  min_grid
}
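nest_min_grid() is stubbed in this commit. One possible shape for the nesting the comment describes, assuming `post_params` holds the postprocessor column names (an illustration, not the eventual implementation):

    nest_min_grid <- function(min_grid, post_params) {
      # fold postprocessor columns into a list-column so that only model
      # parameters remain at the top level; merging these into `.submodels`
      # proper is the open TODO above
      tidyr::nest(min_grid, .post = dplyr::all_of(post_params))
    }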