Skip to content

Commit

Permalink
tune prior specification
Browse files Browse the repository at this point in the history
  • Loading branch information
seabbs committed Aug 13, 2024
1 parent 706b9de commit 92624fa
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 117 deletions.
11 changes: 7 additions & 4 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ create_gp_data <- function(gp = gp_opts(), data) {
ls_sdlog = convert_to_logsd(gp$ls_mean, gp$ls_sd),
ls_min = gp$ls_min,
ls_max = gp$ls_max,
alpha_mean = gp$alpha_mean,
alpha_sd = gp$alpha_sd,
gp_type = data.table::fcase(
gp$kernel == "se", 0,
Expand Down Expand Up @@ -621,7 +622,9 @@ create_initial_conditions <- function(data) {
))

out$alpha <- array(
truncnorm::rtruncnorm(1, a = 0, mean = 0, sd = data$alpha_sd)
truncnorm::rtruncnorm(
1, a = 0, mean = data$alpha_mean, sd = data$alpha_sd
)
)
} else {
out$eta <- array(numeric(0))
Expand All @@ -632,7 +635,7 @@ create_initial_conditions <- function(data) {
out$rep_phi <- array(
truncnorm::rtruncnorm(
1,
a = 0, mean = data$phi_mean, sd = data$phi_sd / 10
a = 0, mean = data$phi_mean, sd = data$phi_sd
)
)
}
Expand All @@ -643,7 +646,7 @@ create_initial_conditions <- function(data) {
}
out$log_R <- array(rnorm(
n = 1, mean = convert_to_logmean(data$r_mean, data$r_sd),
sd = convert_to_logsd(data$r_mean, data$r_sd) * 0.1
sd = convert_to_logsd(data$r_mean, data$r_sd)
))
}

Expand All @@ -658,7 +661,7 @@ create_initial_conditions <- function(data) {
out$frac_obs <- array(truncnorm::rtruncnorm(1,
a = 0, b = 1,
mean = data$obs_scale_mean,
sd = data$obs_scale_sd * 0.1
sd = data$obs_scale_sd
))
} else {
out$frac_obs <- array(numeric(0))
Expand Down
6 changes: 6 additions & 0 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,10 @@ backcalc_opts <- function(prior = c("reports", "none", "infections"),
#'
#' @param ls_min Numeric, defaults to 0. The minimum value of the length scale.
#'
#' @param alpha_mean Numeric, defaults to 0. The mean of the magnitude parameter
#' of the Gaussian process kernel. Should be approximately the expected variance
#' of the logged Rt.
#'
#' @param alpha_sd Numeric, defaults to 0.05. The standard deviation of the
#' magnitude parameter of the Gaussian process kernel. Should be approximately
#' the expected standard deviation of the logged Rt.
Expand Down Expand Up @@ -462,6 +466,7 @@ gp_opts <- function(basis_prop = 0.2,
ls_sd = 7,
ls_min = 0,
ls_max = 60,
alpha_mean = 0,
alpha_sd = 0.025,
kernel = c("matern", "se", "ou", "periodic"),
matern_order = 3 / 2,
Expand Down Expand Up @@ -501,6 +506,7 @@ gp_opts <- function(basis_prop = 0.2,
ls_sd = ls_sd,
ls_min = ls_min,
ls_max = ls_max,
alpha_mean = alpha_mean,
alpha_sd = alpha_sd,
kernel = kernel,
matern_order = matern_order,
Expand Down
Binary file modified vignettes/EpiNow2-unnamed-chunk-11-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
214 changes: 103 additions & 111 deletions vignettes/EpiNow2.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ Load example case data from `{EpiNow2}`.


``` r
reported_cases <- example_confirmed[1:60]
reported_cases <- example_confirmed[1:90]
head(reported_cases)
#> date confirm
#> <Date> <num>
Expand All @@ -118,92 +118,92 @@ estimates <- epinow(
generation_time = gt_opts(example_generation_time),
delays = delay_opts(example_incubation_period + reporting_delay),
rt = rt_opts(prior = list(mean = 2, sd = 0.2)),
stan = stan_opts(cores = 4, backend = "cmdstanr", save_warmup = TRUE),
stan = stan_opts(cores = 4, backend = "cmdstanr", save_warmup = TRUE, control = list(adapt_delta = 0.8)),
gp = gp_opts(alpha_sd = 0.01),
verbose = interactive()
)
#> DEBUG [2024-08-13 14:25:11] epinow: Running in exact mode for samples (across 4 chains each with a warm up of iterations each) and 81 time steps of which 7 are a forecast
#> DEBUG [2024-08-13 15:55:38] epinow: Running in exact mode for samples (across 4 chains each with a warm up of iterations each) and 111 time steps of which 7 are a forecast
#> Running MCMC with 4 parallel chains...
#>
#> Chain 1 Iteration: 1 / 750 [ 0%] (Warmup)
#> Chain 2 Iteration: 1 / 750 [ 0%] (Warmup)
#> Chain 2 Iteration: 50 / 750 [ 6%] (Warmup)
#> Chain 3 Iteration: 1 / 750 [ 0%] (Warmup)
#> Chain 4 Iteration: 1 / 750 [ 0%] (Warmup)
#> Chain 4 Iteration: 50 / 750 [ 6%] (Warmup)
#> Chain 1 Iteration: 50 / 750 [ 6%] (Warmup)
#> Chain 3 Iteration: 50 / 750 [ 6%] (Warmup)
#> Chain 4 Iteration: 50 / 750 [ 6%] (Warmup)
#> Chain 2 Iteration: 100 / 750 [ 13%] (Warmup)
#> Chain 1 Iteration: 100 / 750 [ 13%] (Warmup)
#> Chain 2 Iteration: 150 / 750 [ 20%] (Warmup)
#> Chain 1 Iteration: 150 / 750 [ 20%] (Warmup)
#> Chain 4 Iteration: 100 / 750 [ 13%] (Warmup)
#> Chain 2 Iteration: 150 / 750 [ 20%] (Warmup)
#> Chain 2 Iteration: 200 / 750 [ 26%] (Warmup)
#> Chain 3 Iteration: 100 / 750 [ 13%] (Warmup)
#> Chain 1 Iteration: 200 / 750 [ 26%] (Warmup)
#> Chain 4 Iteration: 150 / 750 [ 20%] (Warmup)
#> Chain 3 Iteration: 100 / 750 [ 13%] (Warmup)
#> Chain 1 Iteration: 100 / 750 [ 13%] (Warmup)
#> Chain 4 Iteration: 200 / 750 [ 26%] (Warmup)
#> Chain 2 Iteration: 250 / 750 [ 33%] (Warmup)
#> Chain 2 Iteration: 251 / 750 [ 33%] (Sampling)
#> Chain 1 Iteration: 250 / 750 [ 33%] (Warmup)
#> Chain 1 Iteration: 251 / 750 [ 33%] (Sampling)
#> Chain 3 Iteration: 150 / 750 [ 20%] (Warmup)
#> Chain 4 Iteration: 200 / 750 [ 26%] (Warmup)
#> Chain 2 Iteration: 300 / 750 [ 40%] (Sampling)
#> Chain 1 Iteration: 300 / 750 [ 40%] (Sampling)
#> Chain 3 Iteration: 200 / 750 [ 26%] (Warmup)
#> Chain 2 Iteration: 350 / 750 [ 46%] (Sampling)
#> Chain 1 Iteration: 150 / 750 [ 20%] (Warmup)
#> Chain 4 Iteration: 250 / 750 [ 33%] (Warmup)
#> Chain 4 Iteration: 251 / 750 [ 33%] (Sampling)
#> Chain 1 Iteration: 350 / 750 [ 46%] (Sampling)
#> Chain 2 Iteration: 400 / 750 [ 53%] (Sampling)
#> Chain 3 Iteration: 200 / 750 [ 26%] (Warmup)
#> Chain 1 Iteration: 200 / 750 [ 26%] (Warmup)
#> Chain 2 Iteration: 350 / 750 [ 46%] (Sampling)
#> Chain 4 Iteration: 300 / 750 [ 40%] (Sampling)
#> Chain 1 Iteration: 400 / 750 [ 53%] (Sampling)
#> Chain 1 Iteration: 250 / 750 [ 33%] (Warmup)
#> Chain 1 Iteration: 251 / 750 [ 33%] (Sampling)
#> Chain 2 Iteration: 400 / 750 [ 53%] (Sampling)
#> Chain 3 Iteration: 250 / 750 [ 33%] (Warmup)
#> Chain 3 Iteration: 251 / 750 [ 33%] (Sampling)
#> Chain 2 Iteration: 450 / 750 [ 60%] (Sampling)
#> Chain 4 Iteration: 350 / 750 [ 46%] (Sampling)
#> Chain 1 Iteration: 450 / 750 [ 60%] (Sampling)
#> Chain 2 Iteration: 500 / 750 [ 66%] (Sampling)
#> Chain 4 Iteration: 400 / 750 [ 53%] (Sampling)
#> Chain 1 Iteration: 300 / 750 [ 40%] (Sampling)
#> Chain 2 Iteration: 450 / 750 [ 60%] (Sampling)
#> Chain 3 Iteration: 300 / 750 [ 40%] (Sampling)
#> Chain 1 Iteration: 500 / 750 [ 66%] (Sampling)
#> Chain 2 Iteration: 550 / 750 [ 73%] (Sampling)
#> Chain 4 Iteration: 450 / 750 [ 60%] (Sampling)
#> Chain 1 Iteration: 550 / 750 [ 73%] (Sampling)
#> Chain 4 Iteration: 400 / 750 [ 53%] (Sampling)
#> Chain 1 Iteration: 350 / 750 [ 46%] (Sampling)
#> Chain 2 Iteration: 500 / 750 [ 66%] (Sampling)
#> Chain 3 Iteration: 350 / 750 [ 46%] (Sampling)
#> Chain 2 Iteration: 600 / 750 [ 80%] (Sampling)
#> Chain 4 Iteration: 500 / 750 [ 66%] (Sampling)
#> Chain 1 Iteration: 600 / 750 [ 80%] (Sampling)
#> Chain 4 Iteration: 450 / 750 [ 60%] (Sampling)
#> Chain 1 Iteration: 400 / 750 [ 53%] (Sampling)
#> Chain 2 Iteration: 550 / 750 [ 73%] (Sampling)
#> Chain 3 Iteration: 400 / 750 [ 53%] (Sampling)
#> Chain 2 Iteration: 650 / 750 [ 86%] (Sampling)
#> Chain 4 Iteration: 550 / 750 [ 73%] (Sampling)
#> Chain 1 Iteration: 650 / 750 [ 86%] (Sampling)
#> Chain 2 Iteration: 700 / 750 [ 93%] (Sampling)
#> Chain 1 Iteration: 700 / 750 [ 93%] (Sampling)
#> Chain 4 Iteration: 600 / 750 [ 80%] (Sampling)
#> Chain 4 Iteration: 500 / 750 [ 66%] (Sampling)
#> Chain 1 Iteration: 450 / 750 [ 60%] (Sampling)
#> Chain 2 Iteration: 600 / 750 [ 80%] (Sampling)
#> Chain 3 Iteration: 450 / 750 [ 60%] (Sampling)
#> Chain 2 Iteration: 750 / 750 [100%] (Sampling)
#> Chain 2 finished in 35.9 seconds.
#> Chain 1 Iteration: 750 / 750 [100%] (Sampling)
#> Chain 4 Iteration: 650 / 750 [ 86%] (Sampling)
#> Chain 1 finished in 36.2 seconds.
#> Chain 4 Iteration: 550 / 750 [ 73%] (Sampling)
#> Chain 1 Iteration: 500 / 750 [ 66%] (Sampling)
#> Chain 2 Iteration: 650 / 750 [ 86%] (Sampling)
#> Chain 3 Iteration: 500 / 750 [ 66%] (Sampling)
#> Chain 4 Iteration: 700 / 750 [ 93%] (Sampling)
#> Chain 4 Iteration: 600 / 750 [ 80%] (Sampling)
#> Chain 1 Iteration: 550 / 750 [ 73%] (Sampling)
#> Chain 2 Iteration: 700 / 750 [ 93%] (Sampling)
#> Chain 3 Iteration: 550 / 750 [ 73%] (Sampling)
#> Chain 4 Iteration: 750 / 750 [100%] (Sampling)
#> Chain 4 finished in 39.9 seconds.
#> Chain 1 Iteration: 600 / 750 [ 80%] (Sampling)
#> Chain 4 Iteration: 650 / 750 [ 86%] (Sampling)
#> Chain 2 Iteration: 750 / 750 [100%] (Sampling)
#> Chain 2 finished in 65.0 seconds.
#> Chain 1 Iteration: 650 / 750 [ 86%] (Sampling)
#> Chain 3 Iteration: 600 / 750 [ 80%] (Sampling)
#> Chain 4 Iteration: 700 / 750 [ 93%] (Sampling)
#> Chain 1 Iteration: 700 / 750 [ 93%] (Sampling)
#> Chain 3 Iteration: 650 / 750 [ 86%] (Sampling)
#> Chain 4 Iteration: 750 / 750 [100%] (Sampling)
#> Chain 4 finished in 71.6 seconds.
#> Chain 1 Iteration: 750 / 750 [100%] (Sampling)
#> Chain 1 finished in 74.7 seconds.
#> Chain 3 Iteration: 700 / 750 [ 93%] (Sampling)
#> Chain 3 Iteration: 750 / 750 [100%] (Sampling)
#> Chain 3 finished in 50.1 seconds.
#> Chain 3 finished in 79.7 seconds.
#>
#> All 4 chains finished successfully.
#> Mean chain execution time: 40.5 seconds.
#> Total execution time: 50.4 seconds.
#> Mean chain execution time: 72.8 seconds.
#> Total execution time: 79.9 seconds.
names(estimates)
#> [1] "estimates" "estimated_reported_cases"
#> [3] "summary" "plots"
#> [5] "timing"
#> [1] "estimates" "estimated_reported_cases" "summary"
#> [4] "plots" "timing"
```

Both summary measures and posterior samples are returned for all parameters in an easily explored format which can be accessed using `summary`. The default is to return a summary table of estimates for key parameters at the latest date partially supported by data.
Expand All @@ -215,13 +215,13 @@ knitr::kable(summary(estimates))



|measure |estimate |
|:--------------------------------|:----------------------|
|New infections per day |2260 (1123 -- 4514) |
|Expected change in daily reports |Likely decreasing |
|Effective reproduction no. |0.9 (0.66 -- 1.2) |
|Rate of growth |-0.03 (-0.13 -- 0.069) |
|Doubling/halving time (days) |-23 (10 -- -5.2) |
|measure |estimate |
|:--------------------------------|:-----------------------|
|New infections per day |466 (273 -- 859) |
|Expected change in daily reports |Likely decreasing |
|Effective reproduction no. |0.86 (0.68 -- 1.1) |
|Rate of growth |-0.044 (-0.11 -- 0.038) |
|Doubling/halving time (days) |-16 (18 -- -6.3) |



Expand All @@ -230,53 +230,45 @@ Summarised parameter estimates can also easily be returned, either filtered for

``` r
head(summary(estimates, type = "parameters", params = "R"))
#> date variable strat type median mean
#> <Date> <char> <char> <char> <num> <num>
#> 1: 2020-02-22 R <NA> estimate 2.188825 2.199833
#> 2: 2020-02-23 R <NA> estimate 2.146680 2.157505
#> 3: 2020-02-24 R <NA> estimate 2.104185 2.113811
#> 4: 2020-02-25 R <NA> estimate 2.055775 2.068968
#> 5: 2020-02-26 R <NA> estimate 2.011460 2.023196
#> 6: 2020-02-27 R <NA> estimate 1.962390 1.976722
#> sd lower_90 lower_50 lower_20 upper_20 upper_50
#> <num> <num> <num> <num> <num> <num>
#> 1: 0.2027924 1.882374 2.054785 2.136798 2.241456 2.331738
#> 2: 0.1958294 1.855209 2.020150 2.101472 2.197808 2.283783
#> 3: 0.1931172 1.815532 1.978267 2.055972 2.153300 2.240020
#> 4: 0.1933264 1.776407 1.933982 2.009352 2.108190 2.191375
#> 5: 0.1949870 1.728702 1.888055 1.962152 2.063184 2.142092
#> 6: 0.1968078 1.685707 1.842048 1.915718 2.010884 2.094873
#> upper_90
#> <num>
#> 1: 2.552528
#> 2: 2.496288
#> 3: 2.453176
#> 4: 2.400488
#> 5: 2.354851
#> 6: 2.313125
#> date variable strat type median mean sd lower_90 lower_50 lower_20
#> <Date> <char> <char> <char> <num> <num> <num> <num> <num> <num>
#> 1: 2020-02-22 R <NA> estimate 2.176950 2.184577 0.1987545 1.882753 2.039355 2.123910
#> 2: 2020-02-23 R <NA> estimate 2.131085 2.140555 0.1910809 1.854654 2.002222 2.083422
#> 3: 2020-02-24 R <NA> estimate 2.086505 2.095422 0.1853327 1.820169 1.962838 2.038096
#> 4: 2020-02-25 R <NA> estimate 2.039110 2.049337 0.1811049 1.778029 1.920557 1.990486
#> 5: 2020-02-26 R <NA> estimate 1.991250 2.002479 0.1779291 1.736063 1.877828 1.940956
#> 6: 2020-02-27 R <NA> estimate 1.941435 1.955039 0.1753258 1.692931 1.831690 1.893360
#> upper_20 upper_50 upper_90
#> <num> <num> <num>
#> 1: 2.224172 2.311905 2.529553
#> 2: 2.180886 2.264790 2.473291
#> 3: 2.134336 2.219297 2.424412
#> 4: 2.087078 2.169872 2.373263
#> 5: 2.038166 2.118070 2.324234
#> 6: 1.987172 2.067273 2.276156
```

Reported cases are returned in a separate data frame in order to streamline the reporting of forecasts and for model evaluation.


``` r
head(summary(estimates, output = "estimated_reported_cases"))
#> date type median mean sd lower_90
#> <Date> <char> <num> <num> <num> <num>
#> 1: 2020-02-22 gp_rt 73 74.6790 20.75308 45
#> 2: 2020-02-23 gp_rt 84 87.0700 24.59519 52
#> 3: 2020-02-24 gp_rt 83 85.5970 24.17349 51
#> 4: 2020-02-25 gp_rt 78 79.3880 21.86777 47
#> 5: 2020-02-26 gp_rt 77 79.6855 22.14036 48
#> 6: 2020-02-27 gp_rt 108 110.3890 30.11093 66
#> lower_50 lower_20 upper_20 upper_50 upper_90
#> <num> <num> <num> <num> <num>
#> 1: 60 67.6 78 87.00 112
#> 2: 70 79.0 90 102.00 131
#> 3: 69 77.0 89 100.00 129
#> 4: 63 72.0 84 93.25 116
#> 5: 64 72.0 82 93.00 119
#> 6: 89 101.0 115 128.00 162
#> date type median mean sd lower_90 lower_50 lower_20 upper_20 upper_50
#> <Date> <char> <num> <num> <num> <num> <num> <num> <num> <num>
#> 1: 2020-02-22 gp_rt 74.5 76.5690 19.68764 48.00 63 70 79.0 88
#> 2: 2020-02-23 gp_rt 82.0 83.4560 21.27288 52.00 68 77 87.4 97
#> 3: 2020-02-24 gp_rt 78.0 79.4430 20.82096 49.00 65 73 83.0 92
#> 4: 2020-02-25 gp_rt 71.0 72.9020 18.97486 44.95 59 67 76.0 85
#> 5: 2020-02-26 gp_rt 86.0 88.0270 22.32343 54.00 72 81 92.0 102
#> 6: 2020-02-27 gp_rt 113.0 114.3555 27.38125 73.00 95 106 120.0 131
#> upper_90
#> <num>
#> 1: 111.00
#> 2: 121.00
#> 3: 116.00
#> 4: 107.05
#> 5: 129.00
#> 6: 163.00
```

A range of plots are returned (with the single summary plot shown below). These plots can also be generated using the following `plot` method.
Expand Down Expand Up @@ -325,26 +317,26 @@ estimates <- regional_epinow(
gp = NULL,
stan = stan_opts(cores = 4, warmup = 250, samples = 1000)
)
#> INFO [2024-08-13 14:26:07] Producing following optional outputs: regions, summary, samples, plots, latest
#> INFO [2024-08-13 14:26:07] Reporting estimates using data up to: 2020-04-21
#> INFO [2024-08-13 14:26:07] No target directory specified so returning output
#> INFO [2024-08-13 14:26:07] Producing estimates for: testland, realland
#> INFO [2024-08-13 14:26:07] Regions excluded: none
#> INFO [2024-08-13 15:57:03] Producing following optional outputs: regions, summary, samples, plots, latest
#> INFO [2024-08-13 15:57:03] Reporting estimates using data up to: 2020-05-21
#> INFO [2024-08-13 15:57:03] No target directory specified so returning output
#> INFO [2024-08-13 15:57:03] Producing estimates for: testland, realland
#> INFO [2024-08-13 15:57:03] Regions excluded: none
#> Error in eval(expr, envir, enclos) :
#> Exception: variable does not exist; processing stage=data initialization; variable name=delay_mean_mean; base type=double (in 'estimate_infections', line 613, column 2 to column 40)
#> Error in eval(expr, envir, enclos) :
#> Exception: variable does not exist; processing stage=data initialization; variable name=delay_mean_mean; base type=double (in 'estimate_infections', line 613, column 2 to column 40)
#> DEBUG [2024-08-13 14:26:07] testland: Running in exact mode for 1000 samples (across 4 chains each with a warm up of 250 iterations each) and 81 time steps of which 7 are a forecast
#> DEBUG [2024-08-13 14:26:08] realland: Running in exact mode for 1000 samples (across 4 chains each with a warm up of 250 iterations each) and 81 time steps of which 7 are a forecast
#> INFO [2024-08-13 14:26:08] Completed regional estimates
#> INFO [2024-08-13 14:26:08] Regions with estimates: 0
#> INFO [2024-08-13 14:26:08] Regions with runtime errors: 2
#> INFO [2024-08-13 14:26:08] Producing summary
#> INFO [2024-08-13 14:26:08] No summary directory specified so returning summary output
#> INFO [2024-08-13 14:26:08] Errors caught whilst generating summary statistics:
#> INFO [2024-08-13 14:26:08] Error: Object 'variable' not found amongst []
#> DEBUG [2024-08-13 15:57:03] testland: Running in exact mode for 1000 samples (across 4 chains each with a warm up of 250 iterations each) and 111 time steps of which 7 are a forecast
#> DEBUG [2024-08-13 15:57:03] realland: Running in exact mode for 1000 samples (across 4 chains each with a warm up of 250 iterations each) and 111 time steps of which 7 are a forecast
#> INFO [2024-08-13 15:57:03] Completed regional estimates
#> INFO [2024-08-13 15:57:03] Regions with estimates: 0
#> INFO [2024-08-13 15:57:03] Regions with runtime errors: 2
#> INFO [2024-08-13 15:57:03] Producing summary
#> INFO [2024-08-13 15:57:03] No summary directory specified so returning summary output
#> INFO [2024-08-13 15:57:03] Errors caught whilst generating summary statistics:
#> INFO [2024-08-13 15:57:03] Error: Object 'variable' not found amongst []
#>
#> INFO [2024-08-13 14:26:08] No target directory specified so returning timings
#> INFO [2024-08-13 15:57:03] No target directory specified so returning timings
```

Results from each region are stored in a `regional` list with across region summary measures and plots stored in a `summary` list. All results can be set to be internally saved by setting the `target_folder` and `summary_dir` arguments. Each region can be estimated in parallel using the `{future}` package (when in most scenarios `cores` should be set to 1). For routine use each MCMC chain can also be run in parallel (with `future` = TRUE) with a time out (`max_execution_time`) allowing for partial results to be returned if a subset of chains is running longer than expected. See the documentation for the `{future}` package for details on nested futures.
Expand Down
Loading

0 comments on commit 92624fa

Please sign in to comment.