diff --git a/R/run.R b/R/run.R index fae5c98..770412d 100644 --- a/R/run.R +++ b/R/run.R @@ -115,17 +115,21 @@ run_gp = function(dat, ..., max_grid_size = 2000, check_df(dat) - if (!is.null(param_grid)) { + if (is.null(param_grid)) { cr_res = check_ranges(dat, param_ranges) dat = cr_res[[1]] param_ranges = cr_res[[2]] + x_grid = get_x_grid(max_grid_size, param_ranges, param_grid) + x_grid_cent = center_grid(x_grid, param_ranges) + centered_dat = center_dat(dat, param_ranges) + } else { + # TODO: make sure this handles uneven user-provided grids correctly. + emp_ranges = param_grid |> lapply(frange) |> qDT() + x_grid = param_grid + x_grid_cent = param_grid |> center_grid(emp_ranges) + centered_dat = center_dat(dat, emp_ranges) } - x_grid = get_x_grid(max_grid_size, param_ranges, param_grid) - x_grid_cent = center_grid(x_grid, param_ranges) - - centered_dat = center_dat(dat, param_ranges) - X = centered_dat |> get_vars("_cent", regex = TRUE) |> qM() list(run_gp_model(X = X, y = dat$rating, X_pred = x_grid_cent, ...), @@ -153,7 +157,7 @@ run_gp = function(dat, ..., max_grid_size = 2000, suggest_next = function(dat, ..., max_grid_size = 2000, param_ranges = create_ranges(), param_grid = NULL, offset = .25, - lambda = .01) { + lambda = .1) { run_res = run_gp(dat, max_grid_size = max_grid_size, diff --git a/man/suggest_next.Rd b/man/suggest_next.Rd index 950f6bb..057df53 100644 --- a/man/suggest_next.Rd +++ b/man/suggest_next.Rd @@ -11,7 +11,7 @@ suggest_next( param_ranges = create_ranges(), param_grid = NULL, offset = 0.25, - lambda = 0.01 + lambda = 0.1 ) } \arguments{ diff --git a/src/stan/gp_mod.stan b/src/stan/gp_mod.stan index 2ca4597..e74349f 100644 --- a/src/stan/gp_mod.stan +++ b/src/stan/gp_mod.stan @@ -33,10 +33,10 @@ transformed parameters { } model { - // rho ~ inv_gamma(3,3); - rho ~ std_normal(); + rho ~ inv_gamma(3,3); + // rho ~ std_normal(); // alpha ~ std_normal(); - alpha ~ inv_gamma(1.5,1.5); + alpha ~ inv_gamma(3,3); sigma ~ student_t(3, 0, .2); y ~ multi_normal_cholesky(mu, L_K); @@ -48,9 +48,16 @@ generated quantities { { matrix[N, N_pred] K_x_x_pred = gp_exp_quad_cov(x, x_pred, alpha, rho); vector[N] K_div_y = mdivide_right_tri_low(mdivide_left_tri_low(L_K, y)', L_K)'; - f_mean = K_x_x_pred' * K_div_y; + vector[N_pred] f_mu = K_x_x_pred' * K_div_y; + matrix[N, N_pred] v_pred = mdivide_left_tri_low(L_K, K_x_x_pred); + matrix[N_pred, N_pred] cov_f2 = gp_exp_quad_cov(x_pred, alpha, rho) - v_pred' * v_pred; + + f_mean = multi_normal_rng(f_mu, add_diag(cov_f2, rep_vector(delta, N_pred))); + for (i in 1:N_pred) { f_star[i] = normal_rng(f_mean[i], sigma); } } + + }