Skip to content

Commit

Permalink
update docs, add_vars
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewGhazi committed Jul 16, 2024
1 parent 259f71f commit 5e82393
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 20 deletions.
33 changes: 15 additions & 18 deletions R/run.R
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,20 @@ run_gp = function(dat, ..., max_grid_size = 2000,
#' @param lambda tradeoff between weighting posterior predictive variance and expected
#' improvement at grid values.
#' @details The acquisition function is \code{lambda*f_star_var + (1-lambda)*exp_imp}.
#' Higher values of lambda up-weight posterior predictive variance, leading to more
#' exploration over exploitation. Lower lambda values up-weight expected improvement over \code{max(dat$rating) - offset}.
#' Higher values of lambda up-weight posterior predictive variance, leading to more
#' exploration over exploitation. Lower lambda values up-weight expected improvement
#' over \code{max(dat$rating) - offset}.
#'
#' For the sake of simplicity, the range of each parameter in the grid is linearly
#' scaled to an N-dimensional hypercube that spans -3 to 3 on each edge. So the model is
#' insensitive to the range of grid values. It won't make a difference if your grinder
#' shows different numbers or something.
#'
#' It's normal for the sampling to slow down dramatically after the warmup phase. This
#' is because while it's fast to fit a GP to a tiny number of observations, it's much
#' more expensive to evaluate the GP over the parameter grid. This happens in the
#' generated quantities block of the model, which only gets evaluated in the sampling
#' phase, which is why it's slow. Whatever man. My CPU has 16 physical cores.
#' @returns a list with elements:
#' \itemize{
#' \item{draws_df}{a draws data frame of model parameters and grid point predictive draws f_star}
Expand Down Expand Up @@ -184,21 +196,6 @@ suggest_next = function(dat, ..., max_grid_size = 2000,

if (all(exp_imp < .Machine$double.eps^0.5)) cli::cli_warn("All expected improvement values near zero. You may need to run the chains for longer or raise {.var offset}.")

# pred_g = x_grid[max_pred_dens,,drop=FALSE][,"gc"]

# exp_imp_post = data.table(variable = colnames(exp_imp),
# mean = exp_imp |> colMeans(),
# i = 1:ncol(exp_imp))
#
# post_range = exp_imp_post$mean |> range()

# qDT(x_grid) |> mtt(i = 1:nrow(x_grid)) |>
# sbt(dplyr::near(gc, pred_g)) |>
# join(exp_imp_post, on = "i", validate = "1:1") |>
# ggplot(aes(tc, bc)) +
# geom_tile(aes(fill = mean)) +
# scale_fill_viridis_c(limits = post_range)

# posterior uncertainty ----
f_mean_sd = f_mean_mat |> fsd()

Expand All @@ -209,7 +206,7 @@ suggest_next = function(dat, ..., max_grid_size = 2000,
acq_df = data.table(post_sd = f_mean_sd,
exp_imp = exp_imp,
acq = combined_acq) |>
cbind(x_grid)
add_vars(x_grid)

suggest = acq_df |> sbt(whichv(acq, fmax(acq)))

Expand Down
16 changes: 14 additions & 2 deletions man/suggest_next.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 5e82393

Please sign in to comment.