From 2ecfef54f55b4e5e58a07e2336b11d9cf4c9c2fd Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 14 Aug 2024 22:55:19 +0100 Subject: [PATCH 01/10] setup broadcasting for rw --- inst/stan/functions/rt.stan | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/inst/stan/functions/rt.stan b/inst/stan/functions/rt.stan index 418c07675..c81494fbd 100644 --- a/inst/stan/functions/rt.stan +++ b/inst/stan/functions/rt.stan @@ -3,21 +3,17 @@ vector update_Rt(int t, real log_R, vector noise, array[] int bps, array[] real bp_effects, int stationary) { // define control parameters int bp_n = num_elements(bp_effects); - int bp_c = 0; int gp_n = num_elements(noise); // define result vectors vector[t] bp = rep_vector(0, t); - vector[t] gp = rep_vector(0, t); + vector[t] gp; vector[t] R; // initialise breakpoints if (bp_n) { - for (s in 1:t) { - if (bps[s]) { - bp_c += bps[s]; - bp[s] = bp_effects[bp_c]; - } + vector[t] bp = rep_vector(0, t); + if (bp_n) { + bp = cumulative_sum({0, bp_effects}[bps + 1]); } - bp = cumulative_sum(bp); } //initialise gaussian process if (gp_n) { @@ -32,10 +28,7 @@ vector update_Rt(int t, real log_R, vector noise, array[] int bps, gp = cumulative_sum(gp); } } - // Calculate Rt - R = rep_vector(log_R, t) + bp + gp; - R = exp(R); - return(R); + return(exp(rep_vector(log_R, t) + bp + gp)); } // Rt priors void rt_lp(vector log_R, array[] real initial_infections, array[] real initial_growth, From a2d48a73ef9e986bcb0583dc72dc527e6c875509 Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 14 Aug 2024 23:43:03 +0100 Subject: [PATCH 02/10] make wider changes to make broadcasting approach easier --- inst/stan/estimate_infections.stan | 9 +++++++-- inst/stan/functions/rt.stan | 24 +++++++++++++----------- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/inst/stan/estimate_infections.stan b/inst/stan/estimate_infections.stan index 1e97203e4..d50af2692 100644 --- a/inst/stan/estimate_infections.stan +++ b/inst/stan/estimate_infections.stan @@ -31,6 +31,11 @@ transformed data{ // Rt real r_logmean = log(r_mean^2 / sqrt(r_sd^2 + r_mean^2)); real r_logsd = sqrt(log(1 + (r_sd^2 / r_mean^2))); + // Setup RW + array[bp_n] int breakpoints_; + for (i in 1:bp_n) { + breakpoints_[i] = breakpoints[i] + 1; + } array[delay_types] int delay_type_max; profile("assign max") { @@ -51,7 +56,7 @@ parameters{ array[estimate_r] real initial_infections ; // seed infections array[estimate_r && seeding_time > 1 ? 1 : 0] real initial_growth; // seed growth rate array[bp_n > 0 ? 1 : 0] real bp_sd; // standard deviation of breakpoint effect - array[bp_n] real bp_effects; // Rt breakpoint effects + vector[bp_n] bp_effects; // Rt breakpoint effects // observation model vector[delay_params_length] delay_params; // delay parameters @@ -85,7 +90,7 @@ transformed parameters { } profile("R") { R = update_Rt( - ot_h, log_R[estimate_r], noise, breakpoints, bp_effects, stationary + ot_h, log_R[estimate_r], noise, breakpoints_, bp_effects, stationary ); } profile("infections") { diff --git a/inst/stan/functions/rt.stan b/inst/stan/functions/rt.stan index c81494fbd..be860212d 100644 --- a/inst/stan/functions/rt.stan +++ b/inst/stan/functions/rt.stan @@ -1,22 +1,22 @@ // update a vector of Rts vector update_Rt(int t, real log_R, vector noise, array[] int bps, - array[] real bp_effects, int stationary) { + vector bp_effects, int stationary) { // define control parameters int bp_n = num_elements(bp_effects); int gp_n = num_elements(noise); - // define result vectors - vector[t] bp = rep_vector(0, t); - vector[t] gp; - vector[t] R; + // Set up Rt intercept + vector[t] R = rep_vector(log_R, t); + // initialise breakpoints if (bp_n) { - vector[t] bp = rep_vector(0, t); - if (bp_n) { - bp = cumulative_sum({0, bp_effects}[bps + 1]); - } + vector[bp_n + 1] bp0; + bp0[1] = 0; + bp0[2:(bp_n + 1)] = bp_effects; + R = R + cumulative_sum(bp0[bps]); } //initialise gaussian process if (gp_n) { + vector[t] gp; if (stationary) { gp[1:gp_n] = noise; // fix future gp based on last estimated @@ -24,15 +24,17 @@ vector update_Rt(int t, real log_R, vector noise, array[] int bps, gp[(gp_n + 1):t] = rep_vector(noise[gp_n], t - gp_n); } } else { + gp[1] = 0; gp[2:(gp_n + 1)] = noise; gp = cumulative_sum(gp); } + R = R + gp; } - return(exp(rep_vector(log_R, t) + bp + gp)); + return(exp(R)); } // Rt priors void rt_lp(vector log_R, array[] real initial_infections, array[] real initial_growth, - array[] real bp_effects, array[] real bp_sd, int bp_n, int seeding_time, + vector bp_effects, array[] real bp_sd, int bp_n, int seeding_time, real r_logmean, real r_logsd, real prior_infections, real prior_growth) { // prior on R From 6c758d3fc95a0f1c9f8ddf3b7f626d328fdc2bb2 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 15 Aug 2024 13:17:35 +0100 Subject: [PATCH 03/10] update interface and test code --- R/create.R | 23 +++++++++---- inst/stan/estimate_infections.stan | 7 +--- inst/stan/functions/rt.stan | 52 ++++++++++++++++++++++-------- tests/testthat/test-stan-rt.R | 12 +++---- 4 files changed, 62 insertions(+), 32 deletions(-) diff --git a/R/create.R b/R/create.R index 89c61a67c..d9d45992d 100644 --- a/R/create.R +++ b/R/create.R @@ -261,9 +261,13 @@ create_future_rt <- function(future = c("latest", "project", "estimate"), #' #' # using breakpoints #' create_rt_data(rt_opts(use_breakpoints = TRUE), breakpoints = rep(1, 10)) +#' +#' # using random walk +#' create_rt_data(rt_opts(rw = 7), breakpoints = rep(1, 10)) #' } create_rt_data <- function(rt = rt_opts(), breakpoints = NULL, delay = 0, horizon = 0) { + # Define if GP is on or off if (is.null(rt)) { rt <- rt_opts( @@ -279,24 +283,29 @@ create_rt_data <- function(rt = rt_opts(), breakpoints = NULL, ) # apply random walk if (rt$rw != 0) { - breakpoints <- as.integer(seq_along(breakpoints) %% rt$rw == 0) + breakpoints <- seq_along(breakpoints) + breakpoints <- floor(breakpoints / rt$rw) if (!(rt$future == "project")) { max_bps <- length(breakpoints) - horizon + future_rt$from if (max_bps < length(breakpoints)) { - breakpoints[(max_bps + 1):length(breakpoints)] <- 0 + breakpoints[(max_bps + 1):length(breakpoints)] <- breakpoints[max_bps] } } + }else { + if (is.null(breakpoints) || sum(breakpoints) == 0) { + rt$use_breakpoints <- FALSE + } + breakpoints <- cumsum(breakpoints) } - # check breakpoints - if (is.null(breakpoints) || sum(breakpoints) == 0) { - rt$use_breakpoints <- FALSE - } + # add a shift for 0 effect in breakpoints + breakpoints <- breakpoints + 1 + # map settings to underlying gp stan requirements rt_data <- list( r_mean = rt$prior$mean, r_sd = rt$prior$sd, estimate_r = as.numeric(rt$use_rt), - bp_n = ifelse(rt$use_breakpoints, sum(breakpoints, na.rm = TRUE), 0), + bp_n = ifelse(rt$use_breakpoints, max(breakpoints) - 1, 0), breakpoints = breakpoints, future_fixed = as.numeric(future_rt$fixed), fixed_from = future_rt$from, diff --git a/inst/stan/estimate_infections.stan b/inst/stan/estimate_infections.stan index d50af2692..985cfa9c2 100644 --- a/inst/stan/estimate_infections.stan +++ b/inst/stan/estimate_infections.stan @@ -31,11 +31,6 @@ transformed data{ // Rt real r_logmean = log(r_mean^2 / sqrt(r_sd^2 + r_mean^2)); real r_logsd = sqrt(log(1 + (r_sd^2 / r_mean^2))); - // Setup RW - array[bp_n] int breakpoints_; - for (i in 1:bp_n) { - breakpoints_[i] = breakpoints[i] + 1; - } array[delay_types] int delay_type_max; profile("assign max") { @@ -90,7 +85,7 @@ transformed parameters { } profile("R") { R = update_Rt( - ot_h, log_R[estimate_r], noise, breakpoints_, bp_effects, stationary + ot_h, log_R[estimate_r], noise, breakpoints, bp_effects, stationary ); } profile("infections") { diff --git a/inst/stan/functions/rt.stan b/inst/stan/functions/rt.stan index be860212d..8d3a4ca5c 100644 --- a/inst/stan/functions/rt.stan +++ b/inst/stan/functions/rt.stan @@ -1,25 +1,35 @@ -// update a vector of Rts +/** + * Update a vector of effective reproduction numbers (Rt) based on + * an intercept, breakpoints (i.e. a random walk), and a Gaussian + * process. + * + * @param t Length of the time series + * @param log_R Logarithm of the base reproduction number + * @param noise Vector of Gaussian process noise values + * @param bps Array of breakpoint indices + * @param bp_effects Vector of breakpoint effects + * @param stationary Flag indicating whether the Gaussian process is stationary + * (1) or non-stationary (0) + * @return A vector of length t containing the updated Rt values + */ vector update_Rt(int t, real log_R, vector noise, array[] int bps, vector bp_effects, int stationary) { - // define control parameters int bp_n = num_elements(bp_effects); int gp_n = num_elements(noise); - // Set up Rt intercept + vector[t] R = rep_vector(log_R, t); - // initialise breakpoints if (bp_n) { vector[bp_n + 1] bp0; bp0[1] = 0; - bp0[2:(bp_n + 1)] = bp_effects; - R = R + cumulative_sum(bp0[bps]); + bp0[2:(bp_n + 1)] = cumulative_sum(bp_effects); + R = R + bp0[bps]; } - //initialise gaussian process + if (gp_n) { vector[t] gp; if (stationary) { gp[1:gp_n] = noise; - // fix future gp based on last estimated if (t > gp_n) { gp[(gp_n + 1):t] = rep_vector(noise[gp_n], t - gp_n); } @@ -30,22 +40,38 @@ vector update_Rt(int t, real log_R, vector noise, array[] int bps, } R = R + gp; } - return(exp(R)); + + return exp(R); } -// Rt priors + +/** + * Calculate the log-probability of the reproduction number (Rt) priors + * + * @param log_R Logarithm of the base reproduction number + * @param initial_infections Array of initial infection values + * @param initial_growth Array of initial growth rates + * @param bp_effects Vector of breakpoint effects + * @param bp_sd Array of breakpoint standard deviations + * @param bp_n Number of breakpoints + * @param seeding_time Time point at which seeding occurs + * @param r_logmean Log-mean of the prior distribution for the base reproduction number + * @param r_logsd Log-standard deviation of the prior distribution for the base reproduction number + * @param prior_infections Prior mean for initial infections + * @param prior_growth Prior mean for initial growth rates + */ void rt_lp(vector log_R, array[] real initial_infections, array[] real initial_growth, vector bp_effects, array[] real bp_sd, int bp_n, int seeding_time, real r_logmean, real r_logsd, real prior_infections, real prior_growth) { - // prior on R log_R ~ normal(r_logmean, r_logsd); - //breakpoint effects on Rt + if (bp_n > 0) { bp_sd[1] ~ normal(0, 0.1) T[0,]; bp_effects ~ normal(0, bp_sd[1]); } - // initial infections + initial_infections ~ normal(prior_infections, 0.2); + if (seeding_time > 1) { initial_growth ~ normal(prior_growth, 0.2); } diff --git a/tests/testthat/test-stan-rt.R b/tests/testthat/test-stan-rt.R index cddfa1ef9..1b4c40153 100644 --- a/tests/testthat/test-stan-rt.R +++ b/tests/testthat/test-stan-rt.R @@ -32,29 +32,29 @@ test_that("update_Rt works when Rt is fixed", { }) test_that("update_Rt works when Rt is fixed but a breakpoint is present", { expect_equal( - round(update_Rt(5, log(1.2), numeric(0), c(0, 0, 1, 0, 0), 0.1, 0), 2), + round(update_Rt(5, log(1.2), numeric(0), c(1, 1, 2, 2, 2), 0.1, 0), 2), c(1.2, 1.2, rep(1.33, 3)) ) expect_equal( - round(update_Rt(5, log(1.2), numeric(0), c(0, 0, 1, 0, 0), 0.1, 1), 2), + round(update_Rt(5, log(1.2), numeric(0), c(1, 1, 2, 2, 2), 0.1, 1), 2), c(1.2, 1.2, rep(1.33, 3)) ) expect_equal( - round(update_Rt(5, log(1.2), numeric(0), c(0, 1, 1, 0, 0), rep(0.1, 2), 0), 2), + round(update_Rt(5, log(1.2), numeric(0), c(1, 2, 3, 3, 3), rep(0.1, 2), 0), 2), c(1.2, 1.33, rep(1.47, 3)) ) }) test_that("update_Rt works when Rt is variable and a breakpoint is present", { expect_equal( - round(update_Rt(5, log(1.2), rep(0, 4), c(0, 0, 1, 0, 0), 0.1, 0), 2), + round(update_Rt(5, log(1.2), rep(0, 4), c(1, 1, 2, 2, 2), 0.1, 0), 2), c(1.2, 1.2, rep(1.33, 3)) ) expect_equal( - round(update_Rt(5, log(1.2), rep(0, 5), c(0, 0, 1, 0, 0), 0.1, 1), 2), + round(update_Rt(5, log(1.2), rep(0, 5), c(1, 1, 2, 2, 2), 0.1, 1), 2), c(1.2, 1.2, rep(1.33, 3)) ) expect_equal( - round(update_Rt(5, log(1.2), rep(0.1, 4), c(0, 0, 1, 0, 0), 0.1, 0), 2), + round(update_Rt(5, log(1.2), rep(0.1, 4), c(1, 1, 2, 2, 2), 0.1, 0), 2), c(1.20, 1.33, 1.62, 1.79, 1.98) ) }) From f9df44552e3ad43d56aa76a5649c349696071a81 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 15 Aug 2024 15:02:56 +0100 Subject: [PATCH 04/10] add unit tests for Rt_opts --- tests/testthat/setup.R | 2 +- tests/testthat/test-rt_opts.R | 61 +++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 1 deletion(-) create mode 100644 tests/testthat/test-rt_opts.R diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R index e5e31d564..3e099a807 100644 --- a/tests/testthat/setup.R +++ b/tests/testthat/setup.R @@ -9,7 +9,7 @@ if (identical(Sys.getenv("NOT_CRAN"), "true")) { if (!(tolower(Sys.info()[["sysname"]]) %in% "windows")) { suppressMessages( expose_stan_fns(files, - target_dir = system.file("stan/functions", package = "EpiNow2") + target_dir = "inst/stan/functions" ) ) } diff --git a/tests/testthat/test-rt_opts.R b/tests/testthat/test-rt_opts.R new file mode 100644 index 000000000..2d415ff8e --- /dev/null +++ b/tests/testthat/test-rt_opts.R @@ -0,0 +1,61 @@ +test_that("rt_opts returns expected default values", { + result <- rt_opts() + + expect_s3_class(result, "rt_opts") + expect_equal(result$prior, list(mean = 1, sd = 1)) + expect_true(result$use_rt) + expect_equal(result$rw, 0) + expect_true(result$use_breakpoints) + expect_equal(result$future, "latest") + expect_equal(result$pop, 0) + expect_equal(result$gp_on, "R_t-1") +}) + +test_that("rt_opts handles custom inputs correctly", { + result <- rt_opts( + prior = list(mean = 2, sd = 0.5), + use_rt = FALSE, + rw = 7, + use_breakpoints = FALSE, + future = "project", + gp_on = "R0", + pop = 1000000 + ) + + expect_equal(result$prior, list(mean = 2, sd = 0.5)) + expect_false(result$use_rt) + expect_equal(result$rw, 7) + expect_true(result$use_breakpoints) # Should be TRUE when rw > 0 + expect_equal(result$future, "project") + expect_equal(result$pop, 1000000) + expect_equal(result$gp_on, "R0") +}) + +test_that("rt_opts sets use_breakpoints to TRUE when rw > 0", { + result <- rt_opts(rw = 3, use_breakpoints = FALSE) + expect_true(result$use_breakpoints) +}) + +test_that("rt_opts throws error for invalid prior", { + expect_error(rt_opts(prior = list(mean = 1)), + "prior must have both a mean and sd specified") + expect_error(rt_opts(prior = list(sd = 1)), + "prior must have both a mean and sd specified") +}) + +test_that("rt_opts validates gp_on argument", { + expect_error(rt_opts(gp_on = "invalid"), "must be one") +}) + +test_that("rt_opts returns object of correct class", { + result <- rt_opts() + expect_s3_class(result, "rt_opts") + expect_true("list" %in% class(result)) +}) + +test_that("rt_opts handles edge cases correctly", { + result <- rt_opts(rw = 0.1, pop = -1) + expect_equal(result$rw, 0.1) + expect_equal(result$pop, -1) + expect_true(result$use_breakpoints) +}) From f8a4beda1d7cb05432f21cf04e1d5501ff9330e2 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 15 Aug 2024 15:38:57 +0100 Subject: [PATCH 05/10] write unit tests for create_rt_data --- R/create.R | 14 +++-- tests/testthat/test-create_rt_date.R | 88 ++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 4 deletions(-) create mode 100644 tests/testthat/test-create_rt_date.R diff --git a/R/create.R b/R/create.R index d9d45992d..d3e8bedc6 100644 --- a/R/create.R +++ b/R/create.R @@ -273,7 +273,8 @@ create_rt_data <- function(rt = rt_opts(), breakpoints = NULL, rt <- rt_opts( use_rt = FALSE, future = "project", - gp_on = "R0" + gp_on = "R0", + rw = 0 ) } # define future Rt arguments @@ -283,6 +284,10 @@ create_rt_data <- function(rt = rt_opts(), breakpoints = NULL, ) # apply random walk if (rt$rw != 0) { + if (is.null(breakpoints)) { + stop("breakpoints must be supplied when using random walk") + } + breakpoints <- seq_along(breakpoints) breakpoints <- floor(breakpoints / rt$rw) if (!(rt$future == "project")) { @@ -292,11 +297,12 @@ create_rt_data <- function(rt = rt_opts(), breakpoints = NULL, } } }else { - if (is.null(breakpoints) || sum(breakpoints) == 0) { - rt$use_breakpoints <- FALSE - } breakpoints <- cumsum(breakpoints) } + + if (sum(breakpoints) == 0) { + rt$use_breakpoints <- FALSE + } # add a shift for 0 effect in breakpoints breakpoints <- breakpoints + 1 diff --git a/tests/testthat/test-create_rt_date.R b/tests/testthat/test-create_rt_date.R new file mode 100644 index 000000000..748ae80d4 --- /dev/null +++ b/tests/testthat/test-create_rt_date.R @@ -0,0 +1,88 @@ +test_that("create_rt_data returns expected default values", { + result <- create_rt_data() + + expect_type(result, "list") + expect_equal(result$r_mean, 1) + expect_equal(result$r_sd, 1) + expect_equal(result$estimate_r, 1) + expect_equal(result$bp_n, 0) + expect_equal(result$breakpoints, numeric(0)) + expect_equal(result$future_fixed, 1) + expect_equal(result$fixed_from, 0) + expect_equal(result$pop, 0) + expect_equal(result$stationary, 0) + expect_equal(result$future_time, 0) +}) + +test_that("create_rt_data handles NULL rt input correctly", { + result <- create_rt_data(rt = NULL) + + expect_equal(result$estimate_r, 0) + expect_equal(result$future_fixed, 0) + expect_equal(result$stationary, 1) +}) + +test_that("create_rt_data handles custom rt_opts correctly", { + custom_rt <- rt_opts( + prior = list(mean = 2, sd = 0.5), + use_rt = FALSE, + rw = 0, + use_breakpoints = FALSE, + future = "project", + gp_on = "R0", + pop = 1000000 + ) + + result <- create_rt_data(rt = custom_rt, horizon = 7) + + expect_equal(result$r_mean, 2) + expect_equal(result$r_sd, 0.5) + expect_equal(result$estimate_r, 0) + expect_equal(result$pop, 1000000) + expect_equal(result$stationary, 1) + expect_equal(result$future_time, 7) +}) + +test_that("create_rt_data handles breakpoints correctly", { + result <- create_rt_data(rt_opts(use_breakpoints = TRUE), + breakpoints = c(1, 0, 1, 0, 1)) + + expect_equal(result$bp_n, 3) + expect_equal(result$breakpoints, c(2, 2, 3, 3, 4)) +}) + +test_that("create_rt_data handles random walk correctly", { + result <- create_rt_data(rt_opts(rw = 2), + breakpoints = rep(1, 10)) + + expect_equal(result$bp_n, 5) + expect_equal(result$breakpoints, c(1, 2, 2, 3, 3, 4, 4, 5, 5, 6)) +}) + +test_that("create_rt_data throws error for invalid inputs", { + expect_error(create_rt_data(rt_opts(rw = 2)), + "breakpoints must be supplied when using random walk") +}) + +test_that("create_rt_data handles future projections correctly", { + result <- create_rt_data(rt_opts(future = "project"), horizon = 7) + + expect_equal(result$future_fixed, 0) + expect_equal(result$fixed_from, 0) + expect_equal(result$future_time, 7) +}) + +test_that("create_rt_data handles zero sum breakpoints", { + result <- create_rt_data(rt_opts(use_breakpoints = TRUE), + breakpoints = rep(0, 5)) + + expect_equal(result$bp_n, 0) +}) + +test_that("create_rt_data adjusts breakpoints for horizon", { + result <- create_rt_data(rt_opts(rw = 2, future = "latest"), + breakpoints = rep(1, 10), + horizon = 3) + + expect_equal(result$breakpoints, c(1, 2, 2, 3, 3, 4, 4, 4, 4, 4)) +}) From 0f0eca41d5fc9fb813c2549eb08fa114e86f819b Mon Sep 17 00:00:00 2001 From: Sam Abbott Date: Thu, 15 Aug 2024 16:51:09 +0100 Subject: [PATCH 06/10] Update NEWS.md --- NEWS.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/NEWS.md b/NEWS.md index 556bf50e4..9e0ac1bc2 100644 --- a/NEWS.md +++ b/NEWS.md @@ -11,6 +11,7 @@ - The interface for defining delay distributions has been generalised to also cater for continuous distributions - When defining probability distributions these can now be truncated using the `tolerance` argument - Ornstein-Uhlenbeck and 5 / 2 Matérn kernels have been added. By @sbfnk in # and reviewed by @. +- Switch to broadcasting from random walks and added unit tests. By @seabbs in #747 and reviewed by @jamesmbaazam. ## Bug fixes @@ -23,6 +24,7 @@ - Updated the documentation of the dots argument of the `stan_sampling_opts()` to add that the dots are passed to `cmdstanr::sample()`. By @jamesmbaazam in #699 and reviewed by @sbfnk. - `generation_time_opts()` has been shortened to `gt_opts()` to make it easier to specify. Calls to both functions are equivalent. By @jamesmbaazam in #698 and reviewed by @seabbs and @sbfnk . +- Added stan documentation for `update_rt`. By @seabbs in #747 and reviewed by @jamesmbaazam. # EpiNow2 1.5.2 From 55c2bb71378f47d3d9df8e2ae86a8544740a40d7 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 15 Aug 2024 16:52:39 +0100 Subject: [PATCH 07/10] revert setup changes --- tests/testthat/setup.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R index 3e099a807..e5e31d564 100644 --- a/tests/testthat/setup.R +++ b/tests/testthat/setup.R @@ -9,7 +9,7 @@ if (identical(Sys.getenv("NOT_CRAN"), "true")) { if (!(tolower(Sys.info()[["sysname"]]) %in% "windows")) { suppressMessages( expose_stan_fns(files, - target_dir = "inst/stan/functions" + target_dir = system.file("stan/functions", package = "EpiNow2") ) ) } From 434923329dc9356fc191df0099e7ec936f6d8f7d Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 15 Aug 2024 17:38:12 +0100 Subject: [PATCH 08/10] update docs --- man/EpiNow2-package.Rd | 2 +- man/create_rt_data.Rd | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/man/EpiNow2-package.Rd b/man/EpiNow2-package.Rd index 3b1d4e286..f8f7529f3 100644 --- a/man/EpiNow2-package.Rd +++ b/man/EpiNow2-package.Rd @@ -43,7 +43,7 @@ Other contributors: \item Paul Mee \email{paul.mee@lshtm.ac.uk} [contributor] \item Peter Ellis \email{peter.ellis2013nz@gmail.com} [contributor] \item Pietro Monticone \email{pietro.monticone@edu.unito.it} [contributor] - \item Lloyd Chapman \email{lloyd.chapman1@lshtm.ac.uk} [contributor] + \item Lloyd Chapman \email{lloyd.chapman1@lshtm.ac.uk } [contributor] \item Andrew Johnson \email{andrew.johnson@arjohnsonau.com} [contributor] } diff --git a/man/create_rt_data.Rd b/man/create_rt_data.Rd index 4a02c1664..3c5233795 100644 --- a/man/create_rt_data.Rd +++ b/man/create_rt_data.Rd @@ -36,6 +36,9 @@ create_rt_data(rt = NULL) # using breakpoints create_rt_data(rt_opts(use_breakpoints = TRUE), breakpoints = rep(1, 10)) + +# using random walk +create_rt_data(rt_opts(rw = 7), breakpoints = rep(1, 10)) } } \seealso{ From 500b11023f53964762842197d2dc53d681edc6af Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 15 Aug 2024 20:47:19 +0100 Subject: [PATCH 09/10] catch initialisation issue --- inst/stan/functions/rt.stan | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/inst/stan/functions/rt.stan b/inst/stan/functions/rt.stan index 8d3a4ca5c..ad2d877b1 100644 --- a/inst/stan/functions/rt.stan +++ b/inst/stan/functions/rt.stan @@ -14,27 +14,28 @@ */ vector update_Rt(int t, real log_R, vector noise, array[] int bps, vector bp_effects, int stationary) { + // define control parameters int bp_n = num_elements(bp_effects); int gp_n = num_elements(noise); - + // initialise intercept vector[t] R = rep_vector(log_R, t); - + //initialise breakpoints + rw if (bp_n) { vector[bp_n + 1] bp0; bp0[1] = 0; bp0[2:(bp_n + 1)] = cumulative_sum(bp_effects); R = R + bp0[bps]; } - + //initialise gaussian process if (gp_n) { - vector[t] gp; + vector[t] gp = rep_vector(0, t); if (stationary) { gp[1:gp_n] = noise; + // fix future gp based on last estimated if (t > gp_n) { gp[(gp_n + 1):t] = rep_vector(noise[gp_n], t - gp_n); } } else { - gp[1] = 0; gp[2:(gp_n + 1)] = noise; gp = cumulative_sum(gp); } @@ -64,12 +65,12 @@ void rt_lp(vector log_R, array[] real initial_infections, array[] real initial_g real r_logmean, real r_logsd, real prior_infections, real prior_growth) { log_R ~ normal(r_logmean, r_logsd); - + //breakpoint effects on Rt if (bp_n > 0) { bp_sd[1] ~ normal(0, 0.1) T[0,]; bp_effects ~ normal(0, bp_sd[1]); } - + // initial infections initial_infections ~ normal(prior_infections, 0.2); if (seeding_time > 1) { From 7822fe8b28dfcc2b2ec970df59d554cbde17bba3 Mon Sep 17 00:00:00 2001 From: Sam Abbott Date: Wed, 28 Aug 2024 15:38:51 +0100 Subject: [PATCH 10/10] Update NEWS.md Co-authored-by: James Azam --- NEWS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index 0ae098d5c..08cf8a70a 100644 --- a/NEWS.md +++ b/NEWS.md @@ -28,7 +28,7 @@ - Updated the documentation of the dots argument of the `stan_sampling_opts()` to add that the dots are passed to `cmdstanr::sample()`. By @jamesmbaazam in #699 and reviewed by @sbfnk. - `generation_time_opts()` has been shortened to `gt_opts()` to make it easier to specify. Calls to both functions are equivalent. By @jamesmbaazam in #698 and reviewed by @seabbs and @sbfnk . -- Added stan documentation for `update_rt`. By @seabbs in #747 and reviewed by @jamesmbaazam. +- Added stan documentation for `update_rt()`. By @seabbs in #747 and reviewed by @jamesmbaazam. # EpiNow2 1.5.2