Skip to content

Commit

Permalink
add linear kernel support
Browse files Browse the repository at this point in the history
  • Loading branch information
seabbs committed Aug 15, 2024
1 parent 40aae17 commit f781c3c
Show file tree
Hide file tree
Showing 8 changed files with 153 additions and 30 deletions.
1 change: 1 addition & 0 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ create_gp_data <- function(gp = gp_opts(), data) {
gp$kernel == "se", 0,
gp$kernel == "periodic", 1,
gp$kernel == "matern" || gp$kernel == "ou", 2,
gp$kernel == "linear", 3,
default = 2
),
nu = gp$matern_order,
Expand Down
23 changes: 16 additions & 7 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -404,18 +404,19 @@ backcalc_opts <- function(prior = c("reports", "none", "infections"),
#' process. Custom settings can be supplied which override the defaults.
#'
#' @param ls_mean Numeric, defaults to 21 days. The mean of the lognormal
#' length scale.
#' length scale. Not used for linear kernel.
#'
#' @param ls_sd Numeric, defaults to 7 days. The standard deviation of the log
#' normal length scale. If \code{ls_sd = 0}, inverse-gamma prior on Gaussian
#' process length scale will be used with recommended parameters
#' \code{inv_gamma(1.499007, 0.057277 * ls_max)}.
#' \code{inv_gamma(1.499007, 0.057277 * ls_max)}. Not used for linear kernel.
#'
#' @param ls_min Numeric, defaults to 0. The minimum value of the length scale.
#' Not used for linear kernel.
#'
#' @param ls_max Numeric, defaults to 60. The maximum value of the length
#' scale. Updated in [create_gp_data()] to be the length of the input data if
#' this is smaller.
#'
#' @param ls_min Numeric, defaults to 0. The minimum value of the length scale.
#' this is smaller. Not used for linear kernel.
#'
#' @param alpha_mean Numeric, defaults to 0. The mean of the magnitude parameter
#' of the Gaussian process kernel. Should be approximately the expected variance
Expand All @@ -427,7 +428,8 @@ backcalc_opts <- function(prior = c("reports", "none", "infections"),
#'
#' @param kernel Character string, the type of kernel required. Currently
#' supporting the squared exponential kernel ("se"), periodic kernel
#' ("periodic"), Ornstein-Uhlenbeck kernel ("ou"), and Matern kernel ("matern").
#' ("periodic"), Ornstein-Uhlenbeck kernel ("ou"), Matern kernel ("matern"),
#' and linear kernel ("linear").
#'
#' @param matern_order Numeric, defaults to 3/2. Order of Matérn Kernel to use.
#' Common choices are 1/2, 3/2, and 5/2. If `kernel` is set
Expand Down Expand Up @@ -460,6 +462,9 @@ backcalc_opts <- function(prior = c("reports", "none", "infections"),
#'
#' # add a custom length scale
#' gp_opts(ls_mean = 4)
#'
#' # use linear kernel
#' gp_opts(kernel = "linear")
gp_opts <- function(basis_prop = 0.2,
boundary_scale = 1.5,
ls_mean = 21,
Expand All @@ -468,7 +473,7 @@ gp_opts <- function(basis_prop = 0.2,
ls_max = 60,
alpha_mean = 0,
alpha_sd = 0.01,
kernel = c("matern", "se", "ou", "periodic"),
kernel = c("matern", "se", "ou", "periodic", "linear"),
matern_order = 3 / 2,
matern_type,
w0 = 1.0) {
Expand All @@ -494,6 +499,10 @@ gp_opts <- function(basis_prop = 0.2,
matern_order <- Inf
} else if (kernel == "ou") {
matern_order <- 1 / 2
} else if (kernel == "linear") {
if (ls_mean != 21 || ls_sd != 7 || ls_min != 0 || ls_max != 60) {
warning("Length scale parameters are not used for the linear kernel.")
}
} else if (
!(is.infinite(matern_order) || matern_order %in% c(1 / 2, 3 / 2, 5 / 2))
) {
Expand Down
1 change: 1 addition & 0 deletions inst/stan/data/gaussian_process.stan
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
real ls_sdlog; // sdlog for gp lengthscale prior
real<lower=0> ls_min; // Lower bound for the lengthscale
real<lower=0> ls_max; // Upper bound for the lengthscale
real alpha_mean; // mean of the prior on the alpha gp kernel parameter
real alpha_sd; // standard deviation of the prior on the alpha gp kernel parameter
int gp_type; // type of gp, 0 = squared exponential, 1 = periodic, 2 = Matern, 3 = linear
real nu; // smoothness parameter for Matern kernel (used if gp_type = 2)
Expand Down
12 changes: 6 additions & 6 deletions inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ transformed data {

parameters {
// gaussian process
array[fixed ? 0 : 1] real<lower = ls_min, upper = ls_max> rescaled_rho; // length scale of noise GP
array[fixed || gp_type == 3 ? 0 : 1] real<lower = ls_min, upper = ls_max> rescaled_rho; // length scale of noise GP
array[fixed ? 0 : 1] real<lower = 0> alpha; // scale of noise GP
vector[fixed ? 0 : M] eta; // unconstrained noise
// Rt
Expand Down Expand Up @@ -70,7 +70,7 @@ transformed parameters {
profile("update gp") {
if (!fixed) {
noise = update_gp(
PHI, M, L, alpha[1], rescaled_rho[1], eta, gp_type, nu
PHI, M, L, alpha[1], rescaled_rho, eta, gp_type, nu
);
}
}
Expand Down Expand Up @@ -162,10 +162,10 @@ model {
// priors for noise GP
if (!fixed) {
profile("gp lp") {
gaussian_process_lp(
rescaled_rho[1], alpha[1], eta, ls_meanlog, ls_sdlog, ls_min,
ls_max, alpha_sd
);
gaussian_process_lp(alpha[1], eta, alpha_mean, alpha_sd);
if (gp_type != 3) {
lengthscale_lp(rescaled_rho[1], ls_meanlog, ls_sdlog, ls_min, ls_max);
}
}
}

Expand Down
51 changes: 37 additions & 14 deletions inst/stan/functions/gaussian_process.stan
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,26 @@ vector diagSPD_Matern(real nu, real alpha, real rho, real L, int M) {
* @param M Number of basis functions
* @return A vector of spectral densities
*/
vector diagSPD_periodic(real alpha, real rho, int M) {
vector diagSPD_Periodic(real alpha, real rho, int M) {
real a = inv_square(rho);
vector[M] indices = linspaced_vector(M, 1, M);
vector[M] q = exp(log(alpha) + 0.5 * (log(2) - a + to_vector(log_modified_bessel_first_kind(indices, a))));
return append_row(q, q);
}

/**
 * Spectral density for Linear kernel
 *
 * For each basis function index j = 1, ..., M this returns
 * alpha * L^2 / (pi^2 * j^2).
 *
 * @param alpha Scaling parameter
 * @param L Length of the interval
 * @param M Number of basis functions
 * @return A vector of length M of spectral densities
 */
vector diagSPD_Linear(real alpha, real L, int M) {
  vector[M] indices = linspaced_vector(M, 1, M);
  // Elementwise division is needed here: Stan defines `vector / real` but
  // not `real / vector`, so the scalar numerator is divided with `./`.
  return (alpha * square(L)) ./ (square(pi()) * square(indices));
}

/**
* Basis functions for Gaussian Process
*
Expand Down Expand Up @@ -129,45 +142,55 @@ matrix setup_gp(int M, real L, int dimension, int is_periodic, real w0) {
* @param alpha Scaling parameter
* @param rho Length scale parameter
* @param eta Vector of noise terms
* @param type Type of kernel (0: SE, 1: Periodic, 2: Matern)
* @param type Type of kernel (0: SE, 1: Periodic, 2: Matern, 3: Linear)
* @param nu Smoothness parameter for Matern kernel
* @return A vector of updated noise terms
*/
vector update_gp(matrix PHI, int M, real L, real alpha,
real rho, vector eta, int type, real nu) {
array[] real rho, vector eta, int type, real nu) {
vector[M] diagSPD; // spectral density

// GP in noise - spectral densities
if (type == 0) {
diagSPD = diagSPD_EQ(alpha, rho, L, M);
diagSPD = diagSPD_EQ(alpha, rho[1], L, M);
} else if (type == 1) {
diagSPD = diagSPD_periodic(alpha, rho, M);
diagSPD = diagSPD_Periodic(alpha, rho[1], M);
} else if (type == 2) {
diagSPD = diagSPD_Matern(nu, alpha, rho, L, M);
diagSPD = diagSPD_Matern(nu, alpha, rho[1], L, M);
} else if (type == 3) {
diagSPD = diagSPD_Linear(alpha, L, M);
}
return PHI * (diagSPD .* eta);
}

/**
* Priors for Gaussian process
* Prior for Gaussian process length scale
*
* @param rho Length scale parameter
* @param alpha Scaling parameter
* @param eta Vector of noise terms
* @param ls_meanlog Mean of the log of the length scale
* @param ls_sdlog Standard deviation of the log of the length scale
* @param ls_min Minimum length scale
* @param ls_max Maximum length scale
* @param alpha_sd Standard deviation of alpha
*/
void gaussian_process_lp(real rho, real alpha, vector eta,
real ls_meanlog, real ls_sdlog,
real ls_min, real ls_max, real alpha_sd) {
void lengthscale_lp(real rho, real ls_meanlog, real ls_sdlog,
real ls_min, real ls_max) {
if (ls_sdlog > 0) {
rho ~ lognormal(ls_meanlog, ls_sdlog) T[ls_min, ls_max];
} else {
rho ~ inv_gamma(1.499007, 0.057277 * ls_max) T[ls_min, ls_max];
}
alpha ~ normal(0, alpha_sd) T[0,];
}

/**
* Priors for Gaussian process (excluding length scale)
*
* Places a normal(alpha_mean, alpha_sd) prior, truncated below at zero, on
* alpha and a standard normal prior on every element of eta.
*
* @param alpha Scaling parameter
* @param eta Vector of noise terms
* @param alpha_mean Mean of the (truncated) normal prior on alpha
* @param alpha_sd Standard deviation of alpha
*/
void gaussian_process_lp(real alpha, vector eta, real alpha_mean,
real alpha_sd) {
alpha ~ normal(alpha_mean, alpha_sd) T[0,];
eta ~ std_normal();
}

9 changes: 9 additions & 0 deletions tests/testthat/test-create_gp_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,13 @@ test_that("create_gp_data correctly handles future_fixed", {
data <- list(t = 30, seeding_time = 7, horizon = 7, future_fixed = 1, fixed_from = 2, stationary = 0)
gp_data <- create_gp_data(gp_opts(), data)
expect_equal(gp_data$M, 4)
})

test_that("create_gp_data correctly handles linear kernel", {
  # Minimal model data needed by create_gp_data()
  model_data <- list(
    t = 30,
    seeding_time = 7,
    horizon = 7,
    future_fixed = 0,
    fixed_from = 0,
    stationary = 0
  )
  stan_gp <- create_gp_data(gp_opts(kernel = "linear"), model_data)

  # The linear kernel is encoded as gp_type 3
  expect_equal(stan_gp$gp_type, 3)
})
25 changes: 25 additions & 0 deletions tests/testthat/test-gp_opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,29 @@ test_that("gp_opts stops for incompatible matern_order and matern_type", {

test_that("gp_opts warns about uncommon Matern kernel orders", {
  # An order of 2 is not one of the common choices, so a warning is expected
  expect_warning(
    gp_opts(matern_order = 2),
    regexp = "Uncommon Matern kernel order"
  )
})

test_that("gp_opts handles linear kernel correctly", {
  # With default length scale settings no condition should be signalled
  expect_silent(gp_opts(kernel = "linear"))

  # Changing any one length scale argument away from its default must warn
  unused_msg <- "Length scale parameters are not used for the linear kernel"
  overrides <- list(ls_mean = 30, ls_sd = 10, ls_min = 1, ls_max = 100)
  for (arg in names(overrides)) {
    expect_warning(
      do.call(gp_opts, c(list(kernel = "linear"), overrides[arg])),
      unused_msg
    )
  }

  # The length scale entries are still returned, and the kernel is recorded
  opts <- gp_opts(kernel = "linear")
  ls_fields <- c("ls_mean", "ls_sd", "ls_min", "ls_max")
  expect_true(all(ls_fields %in% names(opts)))

  expect_equal(opts$kernel, "linear")
})
61 changes: 58 additions & 3 deletions tests/testthat/test-stan-guassian-process.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,27 @@ test_that("diagSPD_Matern returns correct dimensions and values", {
expect_equal(result, expected_result, tolerance = 1e-8)
})

test_that("diagSPD_periodic returns correct dimensions and values", {
test_that("diagSPD_Periodic returns correct dimensions and values", {
alpha <- 1.0
rho <- 2.0
M <- 5
result <- diagSPD_periodic(alpha, rho, M)
result <- diagSPD_Periodic(alpha, rho, M)
expect_equal(length(result), 2 * M) # Expect double the dimensions due to append_row
expect_true(all(result > 0)) # Expect spectral density to be positive
})

test_that("diagSPD_Linear returns correct dimensions and values", {
  M <- 5
  L <- 1.0
  alpha <- 1.0
  spd <- diagSPD_Linear(alpha, L, M)

  # One strictly positive value per basis function
  expect_equal(length(spd), M)
  expect_true(all(spd > 0))

  # Closed form: alpha * L^2 / (pi^2 * j^2) for j = 1, ..., M
  j <- linspaced_vector(M, 1, M)
  expect_equal(spd, alpha * L^2 / (pi^2 * j^2), tolerance = 1e-8)
})

test_that("PHI returns correct dimensions and values", {
N <- 5
Expand Down Expand Up @@ -140,4 +152,47 @@ test_that("update_gp returns correct dimensions and values", {
diagSPD <- diagSPD_EQ(alpha, rho, L, M)
expected_result <- PHI %*% (diagSPD * eta)
expect_equal(matrix(result, ncol = 1), expected_result, tolerance = 1e-8)
})
})

test_that("update_gp with linear kernel returns correct dimensions and values", {
  M <- 3
  L <- 1.0
  alpha <- 1.0
  eta <- rep(1, M)
  # 5 observations by 3 basis functions
  basis <- matrix(runif(15), nrow = 5)

  # rho and nu still have to be supplied but are ignored by the linear kernel
  unused_rho <- c(1.0)
  unused_nu <- 1.5
  linear_type <- 3

  noise <- update_gp(
    basis, M, L, alpha, unused_rho, eta, linear_type, unused_nu
  )

  # One value per observation
  expect_equal(length(noise), nrow(basis))

  # Must equal the basis matrix times the scaled noise terms
  spd <- diagSPD_Linear(alpha, L, M)
  expect_equal(
    matrix(noise, ncol = 1),
    basis %*% (spd * eta),
    tolerance = 1e-8
  )
})

test_that("Linear kernel produces a linear GP", {
  N <- 100
  M <- 50
  L <- 10.0
  alpha <- 2.0
  x <- seq(-L, L, length.out = N)

  # Basis function matrix evaluated at x (kept under a name that does not
  # shadow the PHI() function)
  basis <- PHI(N, M, L, x)

  # Reproducible standard normal noise terms
  set.seed(123)
  eta <- rnorm(M)

  # Realised GP: basis functions weighted by the linear kernel densities
  gp_draw <- basis %*% (diagSPD_Linear(alpha, L, M) * eta)

  # Regress the draw on x and summarise the fit once
  fit_summary <- summary(lm(gp_draw ~ x))

  # A straight line should explain nearly all variation, with small residuals
  expect_gt(fit_summary$r.squared, 0.99)
  expect_lt(fit_summary$sigma, 0.1)
})

0 comments on commit f781c3c

Please sign in to comment.