Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move lengthscale prior to dist_spec #890

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

- A bug was fixed where the initial growth was never estimated (i.e. the prior mean was always zero). By @sbfnk in #853 and reviewed by @seabbs.
- A bug was fixed where an internal function for applying a default cdf cutoff failed due to a vector length issue. By @jamesmbaazam in #858 and reviewed by @sbfnk.
- All parameters have been changed to the new parameter interface. By @sbfnk in #871 and reviewed by @seabbs.
- All parameters have been changed to the new parameter interface. By @sbfnk in #871 and #890 and reviewed by @seabbs.
- The Gaussian Process lengthscale is now scaled by half the length of the time series. By @sbfnk in #890 and reviewed by #

## Package changes

Expand Down
41 changes: 13 additions & 28 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -362,31 +362,14 @@ create_gp_data <- function(gp = gp_opts(), data) {
time <- time - 1
}

obs_time <- data$t - data$seeding_time
if (gp$ls_max > obs_time) {
gp$ls_max <- obs_time
}

times <- seq_len(time)

rescaled_times <- (times - mean(times)) / sd(times)
gp$ls_mean <- gp$ls_mean / sd(times)
gp$ls_sd <- gp$ls_sd / sd(times)
gp$ls_min <- gp$ls_min / sd(times)
gp$ls_max <- gp$ls_max / sd(times)

# basis functions
M <- ceiling(time * gp$basis_prop)

# map settings to underlying gp stan requirements
gp_data <- list(
fixed = as.numeric(fixed),
M = M,
L = gp$boundary_scale * max(rescaled_times),
ls_meanlog = convert_to_logmean(gp$ls_mean, gp$ls_sd),
ls_sdlog = convert_to_logsd(gp$ls_mean, gp$ls_sd),
ls_min = gp$ls_min,
ls_max = gp$ls_max,
L = gp$boundary_scale,
gp_type = data.table::fcase(
gp$kernel == "se", 0,
gp$kernel == "periodic", 1,
Expand Down Expand Up @@ -528,6 +511,16 @@ create_stan_data <- function(data, seeding_time,
# gaussian process data
stan_data <- create_gp_data(gp, stan_data)

## process legacy GP arguments (deprecated and will be removed)
if (!is.null(gp) && gp$legacy_arguments) {
  # The legacy arguments are expressed in days, whereas the lengthscale
  # parameter is expressed in units of half the length of the time series.
  # NOTE(review): `time` is not assigned in the visible part of this
  # function -- confirm it holds the number of GP time points here.
  scale <- 0.5 * (time - 1)

  # Rescale on the natural scale *before* converting to log-scale parameters.
  # Dividing meanlog/sdlog by `scale` afterwards is not equivalent: a linear
  # rescaling shifts meanlog by -log(scale) and leaves sdlog unchanged.
  scaled_mean <- gp$ls_mean / scale
  scaled_sd <- gp$ls_sd / scale
  ls_meanlog <- convert_to_logmean(scaled_mean, scaled_sd)
  ls_sdlog <- convert_to_logsd(scaled_mean, scaled_sd)
  ls_max <- gp$ls_max / scale

  gp$ls <- LogNormal(ls_meanlog, ls_sdlog, max = ls_max)
}

# observation model data
stan_data <- c(
stan_data,
Expand All @@ -542,11 +535,13 @@ create_stan_data <- function(data, seeding_time,
stan_data,
create_stan_params(
alpha = gp$alpha,
rescaled_rho = gp$ls,
R0 = rt$prior,
frac_obs = obs$scale,
rep_phi = obs$phi,
lower_bounds = c(
alpha = 0,
rescaled_rho = 0,
R0 = 0,
frac_obs = 0,
rep_phi = 0
Expand Down Expand Up @@ -601,18 +596,8 @@ create_initial_conditions <- function(data) {
if (data$fixed == 0) {
out$eta <- array(rnorm(
ifelse(data$gp_type == 1, data$M * 2, data$M), mean = 0, sd = 0.1))
out$rescaled_rho <- array(rlnorm(1,
meanlog = data$ls_meanlog,
sdlog = ifelse(data$ls_sdlog > 0, data$ls_sdlog, 0.01)
))
out$rescaled_rho <- array(data.table::fcase(
out$rescaled_rho > data$ls_max, data$ls_max - 0.001,
out$rescaled_rho < data$ls_min, data$ls_min + 0.001,
default = out$rescaled_rho
))
} else {
out$eta <- array(numeric(0))
out$rescaled_rho <- array(numeric(0))
}
if (data$estimate_r == 1) {
out$initial_infections <- array(rnorm(1, data$prior_infections, 0.2))
Expand Down
57 changes: 46 additions & 11 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -461,19 +461,20 @@ backcalc_opts <- function(prior = c("reports", "none", "infections"),
#' Defines a list specifying the structure of the approximate Gaussian
#' process. Custom settings can be supplied which override the defaults.
#'
#' @param ls_mean Numeric, defaults to 21 days. The mean of the lognormal
#' length scale.
#' @param ls_mean Deprecated; use `ls` instead.
#'
#' @param ls_sd Numeric, defaults to 7 days. The standard deviation of the log
#' normal length scale. If \code{ls_sd = 0}, inverse-gamma prior on Gaussian
#' process length scale will be used with recommended parameters
#' \code{inv_gamma(1.499007, 0.057277 * ls_max)}.
#' @param ls_sd Deprecated; use `ls` instead.
#'
#' @param ls_min Numeric, defaults to 0. The minimum value of the length scale.
#' @param ls_min Deprecated; use `ls` instead.
#'
#' @param ls_max Numeric, defaults to 60. The maximum value of the length
#' scale. Updated in [create_gp_data()] to be the length of the input data if
#' this is smaller.
#' @param ls_max Deprecated; use `ls` instead.
#'
#' @param ls A `<dist_spec>` giving the prior distribution of the lengthscale
#' parameter of the Gaussian process kernel. This is scaled by half the
#' length of the time series, such that a value of 2 corresponds to the full
#' length of the time series. Defaults to a normal distribution with mean 0.5
#' and sd 0.1, truncated above at 1: `Normal(mean = 0.5, sd = 0.1, max = 1)`
#' (a lower limit of 0 will be enforced automatically to ensure positivity)
#'
#' @param alpha A `<dist_spec>` giving the prior distribution of the magnitude
#' parameter of the Gaussian process kernel. Should be approximately the
Expand Down Expand Up @@ -533,6 +534,7 @@ gp_opts <- function(basis_prop = 0.2,
ls_sd = 7,
ls_min = 0,
ls_max = 60,
ls = Normal(mean = 0.5, sd = 0.1, max = 1),
alpha = Normal(mean = 0, sd = 0.01),
kernel = c("matern", "se", "ou", "periodic"),
matern_order = 3 / 2,
Expand All @@ -555,6 +557,37 @@ gp_opts <- function(basis_prop = 0.2,
"1.7.0", "gp_opts(alpha_sd)", "gp_opts(alpha)"
)
}
# Legacy lengthscale arguments (`ls_mean`, `ls_sd`, `ls_min`, `ls_max`) are
# deprecated in favour of `ls`. Record whether any of them was supplied so
# that downstream code can convert them into an equivalent `ls` prior.
legacy_arguments <- !missing(ls_mean) || !missing(ls_sd) ||
  !missing(ls_min) || !missing(ls_max)
if (legacy_arguments) {
  # Mixing the old and new interfaces is ambiguous, so refuse outright.
  if (!missing(ls)) {
    cli_abort(
      c(
        "!" = "Both {.var ls} and at least one legacy argument
({.var ls_mean}, {.var ls_sd}, {.var ls_min}, {.var ls_max}) have been
specified.",
        "i" = "Only one of them should be used."
      )
    )
  }
  cli_warn(c(
    "!" = "Specifying lengthscale priors via the {.var ls_mean}, {.var ls_sd},
{.var ls_min}, and {.var ls_max} arguments is deprecated.",
    "i" = "Use the {.var ls} argument instead."
  ))
  # A strictly positive lower bound is no longer supported.
  if (ls_min > 0) {
    cli_abort(
      c(
        "!" = "Lower lengthscale bounds of greater than 0 are no longer
supported. If this is a feature you need please open an Issue on the
EpiNow2 GitHub repository."
      )
    )
  }
}


if (!missing(matern_type)) {
if (!missing(matern_order) && matern_type != matern_order) {
Expand Down Expand Up @@ -592,10 +625,12 @@ gp_opts <- function(basis_prop = 0.2,
ls_sd = ls_sd,
ls_min = ls_min,
ls_max = ls_max,
ls = ls,
alpha = alpha,
kernel = kernel,
matern_order = matern_order,
w0 = w0
w0 = w0,
legacy_arguments = legacy_arguments
)

attr(gp, "class") <- c("gp_opts", class(gp))
Expand Down
1 change: 1 addition & 0 deletions R/simulate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ simulate_infections <- function(estimates, R, initial_infections,

data <- c(data, create_stan_params(
alpha = NULL,
rescaled_rho = NULL,
R0 = NULL,
frac_obs = obs$scale,
rep_phi = obs$phi
Expand Down
1 change: 1 addition & 0 deletions inst/stan/data/estimate_infections_params.stan
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
int<lower = 0> alpha_id; // parameter id of alpha (GP magnitude)
int<lower = 0> rescaled_rho_id; // parameter id of rescaled_rho (GP lengthscale)
int<lower = 0> R0_id; // parameter id of R0
int<lower = 0> frac_obs_id; // parameter id of frac_obs
int<lower = 0> rep_phi_id; // parameter id of rep_phi
4 changes: 0 additions & 4 deletions inst/stan/data/gaussian_process.stan
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
real L; // boundary value for infections gp
int<lower=1> M; // basis functions for infections gp
real ls_meanlog; // meanlog for gp lengthscale prior
real ls_sdlog; // sdlog for gp lengthscale prior
real<lower=0> ls_min; // Lower bound for the lengthscale
real<lower=0> ls_max; // Upper bound for the lengthscale
int gp_type; // type of gp, 0 = squared exponential, 1 = periodic, 2 = Matern
real nu; // smoothness parameter for Matern kernel (used if gp_type = 2)
real w0; // fundamental frequency for periodic kernel (used if gp_type = 1)
Expand Down
16 changes: 10 additions & 6 deletions inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ transformed data {
parameters {
vector<lower = params_lower, upper = params_upper>[n_params_variable] params;
// gaussian process
array[fixed ? 0 : 1] real<lower = ls_min, upper = ls_max> rescaled_rho; // length scale of noise GP
vector[fixed ? 0 : gp_type == 1 ? 2*M : M] eta; // unconstrained noise
// Rt
array[estimate_r] real initial_infections; // seed infections
Expand All @@ -70,6 +69,10 @@ transformed parameters {
alpha_id, params_fixed_lookup, params_variable_lookup, params_value,
params
);
real rescaled_rho = get_param(
rescaled_rho_id, params_fixed_lookup, params_variable_lookup,
params_value, params
);
noise = update_gp(
PHI, M, L, alpha, rescaled_rho, eta, gp_type, nu
);
Expand Down Expand Up @@ -176,9 +179,6 @@ model {
if (!fixed) {
profile("gp lp") {
gaussian_process_lp(eta);
if (gp_type != 3) {
lengthscale_lp(rescaled_rho[1], ls_meanlog, ls_sdlog, ls_min, ls_max);
}
}
}

Expand Down Expand Up @@ -233,9 +233,13 @@ generated quantities {
rep_phi_id, params_fixed_lookup, params_variable_lookup, params_value,
params
);
if (!fixed && gp_type != 3) {
if (!fixed) {
real rescaled_rho = get_param(
rescaled_rho_id, params_fixed_lookup, params_variable_lookup,
params_value, params
);
vector[noise_terms] x = linspaced_vector(noise_terms, 1, noise_terms);
rho[1] = rescaled_rho[1] * sd(x);
rho[1] = rescaled_rho * 0.5 * (max(x) - 1);
}

if (estimate_r == 0) {
Expand Down
32 changes: 7 additions & 25 deletions inst/stan/functions/gaussian_process.stan
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ int setup_noise(int ot_h, int t, int horizon, int estimate_r,
*/
matrix setup_gp(int M, real L, int dimension, int is_periodic, real w0) {
vector[dimension] x = linspaced_vector(dimension, 1, dimension);
x = (x - mean(x)) / sd(x);
x = 2 * (x - mean(x)) / (max(x) - 1);
if (is_periodic) {
return PHI_periodic(dimension, M, w0, x);
} else {
Expand All @@ -165,46 +165,28 @@ matrix setup_gp(int M, real L, int dimension, int is_periodic, real w0) {
* @return A vector of updated noise terms
*/
vector update_gp(matrix PHI, int M, real L, real alpha,
array[] real rho, vector eta, int type, real nu) {
real rho, vector eta, int type, real nu) {
vector[type == 1 ? 2 * M : M] diagSPD; // spectral density

// GP in noise - spectral densities
if (type == 0) {
diagSPD = diagSPD_EQ(alpha, rho[1], L, M);
diagSPD = diagSPD_EQ(alpha, rho, L, M);
} else if (type == 1) {
diagSPD = diagSPD_Periodic(alpha, rho[1], M);
diagSPD = diagSPD_Periodic(alpha, rho, M);
} else if (type == 2) {
if (nu == 0.5) {
diagSPD = diagSPD_Matern12(alpha, rho[1], L, M);
diagSPD = diagSPD_Matern12(alpha, rho, L, M);
} else if (nu == 1.5) {
diagSPD = diagSPD_Matern32(alpha, rho[1], L, M);
diagSPD = diagSPD_Matern32(alpha, rho, L, M);
} else if (nu == 2.5) {
diagSPD = diagSPD_Matern52(alpha, rho[1], L, M);
diagSPD = diagSPD_Matern52(alpha, rho, L, M);
} else {
reject("nu must be one of 1/2, 3/2 or 5/2; found nu=", nu);
}
}
return PHI * (diagSPD .* eta);
}

/**
* Prior for Gaussian process length scale
*
* Adds the prior density of the length scale to the target. The prior is
* truncated to [ls_min, ls_max] in both branches:
* - if ls_sdlog > 0, a lognormal(ls_meanlog, ls_sdlog) prior is used;
* - otherwise, an inverse-gamma prior with shape 1.499007 and scale
*   0.057277 * ls_max is used (ls_meanlog is ignored in this case).
*
* @param rho Length scale parameter
* @param ls_meanlog Mean of the log of the length scale
* @param ls_sdlog Standard deviation of the log of the length scale; a value
* that is not greater than 0 selects the inverse-gamma prior
* @param ls_min Minimum length scale (lower truncation bound)
* @param ls_max Maximum length scale (upper truncation bound; also scales the
* inverse-gamma prior)
*/
void lengthscale_lp(real rho, real ls_meanlog, real ls_sdlog,
real ls_min, real ls_max) {
if (ls_sdlog > 0) {
// lognormal prior truncated to [ls_min, ls_max]
rho ~ lognormal(ls_meanlog, ls_sdlog) T[ls_min, ls_max];
} else {
// fallback inverse-gamma prior, truncated to the same bounds
rho ~ inv_gamma(1.499007, 0.057277 * ls_max) T[ls_min, ls_max];
}
}

/**
* Priors for Gaussian process (excluding length scale)
*
Expand Down
22 changes: 12 additions & 10 deletions man/gp_opts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 2 additions & 14 deletions tests/testthat/test-create_gp_data.R
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
test_that("create_gp_data returns correct default values when GP is disabled", {
data <- list(t = 30, seeding_time = 7, horizon = 7, future_fixed = 0, fixed_from = 0)
restricted_time <- 30 - 7 - 1
times <- seq_len(restricted_time)
gp_data <- create_gp_data(NULL, data)
expect_equal(gp_data$fixed, 1)
expect_equal(gp_data$stationary, 1)
expect_equal(gp_data$M, 5) # (30 - 7) * 0.2
expect_equal(gp_data$L, 2.43, tolerance = 0.01)
expect_equal(gp_data$ls_meanlog, convert_to_logmean(21 / sd(times), 7 / sd(times)))
expect_equal(gp_data$ls_sdlog, convert_to_logsd(21, 7))
expect_equal(gp_data$ls_min, 0)
expect_equal(gp_data$ls_max, 3.54, tolerance = 0.01)
expect_equal(gp_data$L, 1.5)
expect_equal(gp_data$alpha, NULL)
expect_equal(gp_data$rescaled_rho, NULL)
expect_equal(gp_data$gp_type, 2) # Default to Matern
expect_equal(gp_data$nu, 3 / 2)
expect_equal(gp_data$w0, 1.0)
Expand All @@ -37,13 +32,6 @@ test_that("create_gp_data sets correct gp_type and nu for different kernels", {
expect_equal(gp_data$nu, 1 / 2)
})

test_that("create_gp_data correctly adjusts ls_max", {
data <- list(t = 30, seeding_time = 7, horizon = 7, future_fixed = 0, fixed_from = 0, stationary = 0)
gp <- gp_opts(ls_max = 50)
gp_data <- create_gp_data(gp, data)
expect_equal(gp_data$ls_max, 3.39, tolerance = 0.01) # 30 - 7 - 7
})

test_that("create_gp_data correctly handles future_fixed", {
data <- list(t = 30, seeding_time = 7, horizon = 7, future_fixed = 1, fixed_from = 2, stationary = 0)
gp_data <- create_gp_data(gp_opts(), data)
Expand Down
Loading
Loading