From dd0f0d1b6a695bd5c3dd72f0dde43eb05cd7a594 Mon Sep 17 00:00:00 2001 From: Andrew Ghazi <6763470+andrewGhazi@users.noreply.github.com> Date: Sun, 30 Jun 2024 16:30:26 -0400 Subject: [PATCH] allow variable brew parameters --- NAMESPACE | 1 + R/check_input.R | 38 ++++++++- R/run.R | 178 ++++++++++++++++++++++++++++++++----------- README.Rmd | 12 +-- README.md | 1 - man/check_df.Rd | 5 ++ man/create_ranges.Rd | 13 ++++ man/run_gp.Rd | 29 +++++++ man/suggest_next.Rd | 30 ++++++++ 9 files changed, 255 insertions(+), 52 deletions(-) create mode 100644 man/create_ranges.Rd create mode 100644 man/run_gp.Rd create mode 100644 man/suggest_next.Rd diff --git a/NAMESPACE b/NAMESPACE index f8e954e..a635285 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,5 +1,6 @@ # Generated by roxygen2: do not edit by hand +export(create_ranges) export(run_gp) export(suggest_next) import(collapse) diff --git a/R/check_input.R b/R/check_input.R index ebedb69..7d19a02 100644 --- a/R/check_input.R +++ b/R/check_input.R @@ -1,7 +1,17 @@ +get_params = function(dat) { + grep("rating", names(dat), value = TRUE, invert = TRUE) +} + +all_within = function(param_vec, range_vec) { + all((param_vec >= range_vec[1]) & (param_vec <= range_vec[2])) +} + + #' Check input data.frame #' @description #' The input data frame should have a limited number of columns and at least two rows -#' +#' @param dat data frame input +#' @param call calling environment check_df = function(dat, call = rlang::caller_env()) { cn = colnames(dat) @@ -11,3 +21,29 @@ check_df = function(dat, call = rlang::caller_env()) { # if (nrow(dat) < 2) cli::cli_abort("Input needs at least two existing observations.", call = call) } + +check_param_olap = function(dat, param_ranges) { + dat_params = get_params(dat) + rng_params = get_params(param_ranges) + + param_int = intersect(dat_params, rng_params) + + all_in_int = all(dat_params %in% param_int) & all(rng_params %in% param_int) + + if (!all_in_int) cli::cli_warn("Non-overlapping columns 
between data and parameter ranges will be dropped.") + +} + +check_ranges = function(dat, param_ranges, call = rlang::caller_env()) { + params = intersect(get_params(dat), get_params(param_ranges)) # keep only parameters present in both inputs, as the overlap warning promises + + check_param_olap(dat, param_ranges) + + within_ranges = mapply(all_within, + dat |> get_vars(params), param_ranges |> get_vars(params)) + + if (!all(within_ranges)) cli::cli_abort("Provided parameter values fall outside the specified ranges. Those with values outside the provided ranges are {.val {names(within_ranges[!within_ranges])}}", + call = call) + + list(dat |> get_vars(c(params, "rating")), param_ranges |> get_vars(params)) +} diff --git a/R/run.R b/R/run.R index 7dd068c..6c31c42 100644 --- a/R/run.R +++ b/R/run.R @@ -1,41 +1,143 @@ +get_grid_vec = function(param_id, param_range) { + if (param_id == "grinder_setting") { + res = seq(param_range[1], param_range[2], by = .5) + } else if (param_id == "temp") { + res = seq(param_range[1], param_range[2], by = 5) + } else if (param_id == "bloom_time") { + res = seq(param_range[1], param_range[2], by = 10) + } else { + res = seq(param_range[1], param_range[2], length.out = 6) + } + + res +} + +form_x_grid = function(max_grid_size, + param_ranges) { + + params = get_params(param_ranges) + + vec_list = mapply(get_grid_vec, + params, param_ranges, + SIMPLIFY = FALSE) + + res = expand.grid(vec_list) |> qDT() + + if (nrow(res) > max_grid_size) cli::cli_abort("Automated grid exceeded the specified {.var max_grid_size}. Either provide your own grid or increase {.var max_grid_size}") + + res +} + +get_x_grid = function(max_grid_size, + param_ranges, param_grid) { + + if (!is.null(param_grid)) { + x_grid = param_grid + } else { + x_grid = form_x_grid(max_grid_size, + param_ranges) + } + + x_grid +} + + +#' Create a range data frame +#' @description This function creates an example data frame of mins and maxs for brew +#' parameter settings. 
That is, the range of grinder settings I want to search is from 4 +#' to 14, temperatures from 170 to 210F, and bloom times from 0 to 60s. +#' #' @export -run_gp = function(dat, ...) { +create_ranges = function() { + data.frame(grinder_setting = c(4,14), + temp = c(170, 210), + bloom_time = c(0, 60)) +} + +get_centers_and_widths = function(param_ranges) { + centers = param_ranges |> sapply(fmean) + widths = (param_ranges |> sapply(diff)) / 2 + list(centers, widths) +} + +center_grid = function(x_grid, param_ranges) { + cents_widths = get_centers_and_widths(param_ranges) - check_df(dat) + centers = cents_widths[[1]] + widths = cents_widths[[2]] + + res = x_grid |> + TRA(centers) |> + TRA(widths, FUN = "/") |> + TRA(rep(3, ncol(x_grid)), FUN = "*") + + names(res) = paste0(names(res), "_cent") + res +} + +center_dat = function(dat, param_ranges) { + cents_widths = get_centers_and_widths(param_ranges) - # TODO adapt centering/scaling, generalize to arbitrary # of parameters - dat = dat |> - mtt(gs_cent = (grinder_setting - 9) / 5 * 3, - temp_cent = (temp - 190) / (20) * 3, - bloom_cent = (bloom_time - 30) / 30 * 3) |> - qDT() + centers = cents_widths[[1]] + widths = cents_widths[[2]] - g_map = data.table(g = seq(4,14, by = .5)) |> - mtt(gc = (g - 9) / 5 * 3) + params = get_params(dat) + + res = dat |> + get_vars(params) |> + TRA(centers) |> + TRA(widths, FUN = "/") |> + TRA(rep(3, length(params)), FUN = "*") + + names(res) = paste0(names(res), "_cent") + + res |> + add_vars(dat$rating) +} + +#' Run the GP +#' @param dat data frame input of brew parameters and rating +#' @param ... arguments passed to cmdstanr's sample method +#' @param max_grid_size maximum number of grid points to evaluate +#' @param param_ranges upper and lower limits of parameter ranges to evaluate +#' @details +#' The function \code{\link{create_ranges}()} will create an example range df. 
+#' +#' @export +run_gp = function(dat, ..., max_grid_size = 2000, + param_ranges = create_ranges(), param_grid = NULL) { - t_map = data.table(t = seq(170, 210, by = 5), - tc = (seq(170, 210, by = 5) - 190) / 20 * 3) + check_df(dat) + cr_res = check_ranges(dat, param_ranges) + dat = cr_res[[1]]; param_ranges = cr_res[[2]] - b_map = data.table(b = seq(0, 60, by = 10), - bc = ((seq(0, 60, by = 10) - 30) / 30) * 3 ) + x_grid = get_x_grid(max_grid_size, param_ranges, param_grid) + x_grid_cent = center_grid(x_grid, param_ranges) - x_grid = expand.grid(gc = g_map$gc, - tc = t_map$tc, - bc = b_map$bc) |> - qM() + centered_dat = center_dat(dat, param_ranges) - X = dat |> slt(gs_cent, temp_cent, bloom_cent) |> qM() + X = centered_dat |> get_vars("_cent", regex=TRUE) |> qM() - list(run_gp_model(X, dat$rating, x_grid, ...), - x_grid) + list(run_gp_model(X = X, y = dat$rating, X_pred = x_grid_cent, ...), + x_grid, + x_grid_cent) } +#' Suggest the next point to try +#' @inheritParams run_gp +#' @param ... arguments passed to cmdstanr's sample method +#' @param offset expected improvement hyperparameter. Higher values encourage more +#' exploration. Interpreted on the same scale as ratings. #' @export -suggest_next = function(dat, x_grid, ...) { +suggest_next = function(dat, ..., max_grid_size = 2000, + param_ranges = create_ranges(), param_grid = NULL, + offset = .25) { - run_res = run_gp(dat, ...) + run_res = run_gp(dat, ..., max_grid_size = max_grid_size, param_ranges = param_ranges, param_grid = param_grid) # forward the grid settings instead of silently falling back to run_gp's defaults + gp_res = run_res[[1]] x_grid = run_res[[2]] + x_grid_cent = run_res[[3]] obs_max = max(dat$rating) @@ -48,15 +150,15 @@ suggest_next = function(dat, x_grid, ...) { max_pred_dens = fsum(acq) |> which.max() - if (max_pred_dens == 1) cli::cli_warn("Selected the first grid point as maximum of the acquisition function. You may need to run the chains for longer.") - - pred_g = x_grid[max_pred_dens,,drop=FALSE][,"gc"] + if (max_pred_dens == 1) cli::cli_warn("Selected the first grid point as maximum of the acquisition function. 
You may need to run the chains for longer or lower {.var offset}.") - acq_post = data.table(variable = colnames(acq), - mean = acq |> colMeans(), - i = 1:ncol(acq)) + # pred_g = x_grid[max_pred_dens,,drop=FALSE][,"gc"] - post_range = acq_post$mean |> range() + # acq_post = data.table(variable = colnames(acq), + # mean = acq |> colMeans(), + # i = 1:ncol(acq)) + # + # post_range = acq_post$mean |> range() # qDT(x_grid) |> mtt(i = 1:nrow(x_grid)) |> # sbt(dplyr::near(gc, pred_g)) |> @@ -65,19 +167,7 @@ suggest_next = function(dat, x_grid, ...) { # geom_tile(aes(fill = mean)) + # scale_fill_viridis_c(limits = post_range) - g_map = data.table(g = seq(4,14, by = .5)) |> - mtt(gc = (g - 9) / 5 * 3) - - t_map = data.table(t = seq(170, 210, by = 5), - tc = (seq(170, 210, by = 5) - 190) / 20 * 3) - - b_map = data.table(b = seq(0, 60, by = 10), - bc = ((seq(0, 60, by = 10) - 30) / 30) * 3 ) - - x_grid[max_pred_dens,,drop=FALSE] |> - qDT() |> - join(g_map, verbose = FALSE) |> - join(t_map, verbose = FALSE) |> - join(b_map, verbose = FALSE) - + list(draws_df = gp_res, + x_grid = x_grid, + suggested = x_grid[max_pred_dens,] ) } diff --git a/README.Rmd b/README.Rmd index 70394c3..eef93f6 100644 --- a/README.Rmd +++ b/README.Rmd @@ -43,16 +43,16 @@ Give the `suggest_next()` function a data frame of brew parameters with ratings ```{r eval=FALSE} library(dyingforacup) +options(mc.cores = 4) - -dat = data.frame(grinder_setting = c(8, 193, 25), - temp = c(7, 195, 20), - bloom_time = c(9, 179, 45), - rating = c(1.1, -.7, -1)) +dat = data.frame(grinder_setting = c( 8, 7, 9), + temp = c(193, 195, 179), + bloom_time = c( 25, 20, 45), + rating = c(1.1, -0.7, -1)) suggest_next(dat, iter_sampling = 4000, - refresh = 1250, + refresh = 0, show_exceptions = FALSE, adapt_delta = .95, parallel_chains = 4) diff --git a/README.md b/README.md index a821076..0e7fbaa 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,6 @@ probability of improving the rating. 
``` r library(dyingforacup) - dat = data.frame(grinder_setting = c(8, 193, 25), temp = c(7, 195, 20), bloom_time = c(9, 179, 45), diff --git a/man/check_df.Rd b/man/check_df.Rd index e151bd6..b1d0ff3 100644 --- a/man/check_df.Rd +++ b/man/check_df.Rd @@ -6,6 +6,11 @@ \usage{ check_df(dat, call = rlang::caller_env()) } +\arguments{ +\item{dat}{data frame input} + +\item{call}{calling environment} +} \description{ The input data frame should have a limited number of columns and at least two rows } diff --git a/man/create_ranges.Rd b/man/create_ranges.Rd new file mode 100644 index 0000000..6c39360 --- /dev/null +++ b/man/create_ranges.Rd @@ -0,0 +1,13 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/run.R +\name{create_ranges} +\alias{create_ranges} +\title{Create a range data frame} +\usage{ +create_ranges() +} +\description{ +This function creates an example data frame of mins and maxs for brew + parameter settings. That is, the range of grinder settings I want to search is from 4 + to 14, temperatures from 170 to 210F, and bloom times from 0 to 60s. +} diff --git a/man/run_gp.Rd b/man/run_gp.Rd new file mode 100644 index 0000000..e671e40 --- /dev/null +++ b/man/run_gp.Rd @@ -0,0 +1,29 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/run.R +\name{run_gp} +\alias{run_gp} +\title{Run the GP} +\usage{ +run_gp( + dat, + ..., + max_grid_size = 2000, + param_ranges = create_ranges(), + param_grid = NULL +) +} +\arguments{ +\item{dat}{data frame input of brew parameters and rating} + +\item{...}{arguments passed to cmdstanr's sample method} + +\item{max_grid_size}{maximum number of grid points to evaluate} + +\item{param_ranges}{upper and lower limits of parameter ranges to evaluate} +} +\description{ +Run the GP +} +\details{ +The function \code{\link{create_ranges}()} will create an example range df. 
+} diff --git a/man/suggest_next.Rd b/man/suggest_next.Rd new file mode 100644 index 0000000..de22146 --- /dev/null +++ b/man/suggest_next.Rd @@ -0,0 +1,30 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/run.R +\name{suggest_next} +\alias{suggest_next} +\title{Suggest the next point to try} +\usage{ +suggest_next( + dat, + ..., + max_grid_size = 2000, + param_ranges = create_ranges(), + param_grid = NULL, + offset = 0.25 +) +} +\arguments{ +\item{dat}{data frame input of brew parameters and rating} + +\item{...}{arguments passed to cmdstanr's sample method} + +\item{max_grid_size}{maximum number of grid points to evaluate} + +\item{param_ranges}{upper and lower limits of parameter ranges to evaluate} + +\item{offset}{expected improvement hyperparameter. Higher values encourage more +exploration. Interpreted on the same scale as ratings.} +} +\description{ +Suggest the next point to try +}