Skip to content

Commit

Permalink
fix error in generated quantities
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewGhazi committed Jul 8, 2024
1 parent 10685f5 commit ab7b41d
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 12 deletions.
18 changes: 11 additions & 7 deletions R/run.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,17 +115,21 @@ run_gp = function(dat, ..., max_grid_size = 2000,

check_df(dat)

if (!is.null(param_grid)) {
if (is.null(param_grid)) {
cr_res = check_ranges(dat, param_ranges)
dat = cr_res[[1]]
param_ranges = cr_res[[2]]
x_grid = get_x_grid(max_grid_size, param_ranges, param_grid)
x_grid_cent = center_grid(x_grid, param_ranges)
centered_dat = center_dat(dat, param_ranges)
} else {
# TODO: make sure this handles uneven user-provided grids correctly.
emp_ranges = param_grid |> lapply(frange) |> qDT()
x_grid = param_grid
x_grid_cent = param_grid |> center_grid(emp_ranges)
centered_dat = center_dat(dat, emp_ranges)
}

x_grid = get_x_grid(max_grid_size, param_ranges, param_grid)
x_grid_cent = center_grid(x_grid, param_ranges)

centered_dat = center_dat(dat, param_ranges)

X = centered_dat |> get_vars("_cent", regex = TRUE) |> qM()

list(run_gp_model(X = X, y = dat$rating, X_pred = x_grid_cent, ...),
Expand Down Expand Up @@ -153,7 +157,7 @@ run_gp = function(dat, ..., max_grid_size = 2000,
suggest_next = function(dat, ..., max_grid_size = 2000,
param_ranges = create_ranges(), param_grid = NULL,
offset = .25,
lambda = .01) {
lambda = .1) {

run_res = run_gp(dat,
max_grid_size = max_grid_size,
Expand Down
2 changes: 1 addition & 1 deletion man/suggest_next.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 11 additions & 4 deletions src/stan/gp_mod.stan
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ transformed parameters {
}

model {
// rho ~ inv_gamma(3,3);
rho ~ std_normal();
rho ~ inv_gamma(3,3);
// rho ~ std_normal();
// alpha ~ std_normal();
alpha ~ inv_gamma(1.5,1.5);
alpha ~ inv_gamma(3,3);
sigma ~ student_t(3, 0, .2);

y ~ multi_normal_cholesky(mu, L_K);
Expand All @@ -48,9 +48,16 @@ generated quantities {
{
matrix[N, N_pred] K_x_x_pred = gp_exp_quad_cov(x, x_pred, alpha, rho);
vector[N] K_div_y = mdivide_right_tri_low(mdivide_left_tri_low(L_K, y)', L_K)';
f_mean = K_x_x_pred' * K_div_y;
vector[N_pred] f_mu = K_x_x_pred' * K_div_y;
matrix[N, N_pred] v_pred = mdivide_left_tri_low(L_K, K_x_x_pred);
matrix[N_pred, N_pred] cov_f2 = gp_exp_quad_cov(x_pred, alpha, rho) - v_pred' * v_pred;

f_mean = multi_normal_rng(f_mu, add_diag(cov_f2, rep_vector(delta, N_pred)));

for (i in 1:N_pred) {
f_star[i] = normal_rng(f_mean[i], sigma);
}
}


}

0 comments on commit ab7b41d

Please sign in to comment.