Skip to content

Commit

Permalink
fix estimate_secondary without delays (#588)
Browse files Browse the repository at this point in the history
* pull scale/conv out of calculate_secondary

* only convolve if delayed

* add regression test

* update stan secondary tests
  • Loading branch information
sbfnk authored Feb 29, 2024
1 parent f0e9c6f commit e1ff1f5
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 42 deletions.
21 changes: 16 additions & 5 deletions inst/stan/estimate_secondary.stan
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,31 @@ transformed parameters {
// calculate secondary reports from primary

{
vector[delay_type_max[delay_id] + 1] delay_rev_pmf;
vector[t] scaled;
vector[t] convolved = rep_vector(1e-5, t);

// scaling of primary reports by fraction observed
if (obs_scale) {
scaled = scale_obs(primary, obs_scale_sd > 0 ? frac_obs[1] : obs_scale_mean);
} else {
scaled = primary;
}

if (delay_id) {
delay_rev_pmf = get_delay_rev_pmf(
vector[delay_type_max[delay_id] + 1] delay_rev_pmf = get_delay_rev_pmf(
delay_id, delay_type_max[delay_id] + 1, delay_types_p, delay_types_id,
delay_types_groups, delay_max, delay_np_pmf,
delay_np_pmf_groups, delay_mean, delay_sd, delay_dist,
0, 1, 0
);
convolved = convolved + convolve_to_report(scaled, delay_rev_pmf, 0);
} else {
delay_rev_pmf = to_vector({ 1 });
convolved = convolved + scaled;
}

secondary = calculate_secondary(
primary, obs, frac_obs, delay_rev_pmf, cumulative, historic,
primary_hist_additive, current, primary_current_additive, t
scaled, convolved, obs, cumulative, historic, primary_hist_additive,
current, primary_current_additive, t
);
}

Expand Down
26 changes: 5 additions & 21 deletions inst/stan/functions/secondary.stan
Original file line number Diff line number Diff line change
@@ -1,22 +1,9 @@
// Calculate secondary reports condition only on primary reports
vector calculate_secondary(vector reports, array[] int obs, array[] real frac_obs,
vector delay_rev_pmf, int cumulative, int historic,
int primary_hist_additive, int current,
int primary_current_additive, int predict) {
int t = num_elements(reports);
int obs_scale = num_elements(frac_obs);
vector[t] scaled_reports;
vector[t] conv_reports = rep_vector(1e-5, t);
vector calculate_secondary(vector scaled_reports, vector conv_reports, array[] int obs,
int cumulative, int historic, int primary_hist_additive,
int current, int primary_current_additive, int predict) {
int t = num_elements(scaled_reports);
vector[t] secondary_reports = rep_vector(0.0, t);
// scaling of reported cases by fraction
if (obs_scale) {
scaled_reports = scale_obs(reports, frac_obs[1]);
}else{
scaled_reports = reports;
}
// convolve from reports to contributions from these reports
conv_reports = conv_reports +
convolve_to_report(scaled_reports, delay_rev_pmf, 0);
// if predicting and using a cumulative target
// combine reports with previous secondary data
for (i in 1:t) {
Expand All @@ -33,10 +20,7 @@ vector calculate_secondary(vector reports, array[] int obs, array[] real frac_ob
if (primary_hist_additive) {
secondary_reports[i] += conv_reports[i];
}else{
if (conv_reports[i] > secondary_reports[i]) {
conv_reports[i] = secondary_reports[i];
}
secondary_reports[i] -= conv_reports[i];
secondary_reports[i] = fmax(0, secondary_reports[i] - conv_reports[i]);
}
}
// update based on current primary reports
Expand Down
29 changes: 21 additions & 8 deletions inst/stan/simulate_secondary.stan
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,33 @@ generated quantities {
array[n, all_dates ? t : h] int sim_secondary;
for (i in 1:n) {
vector[t] secondary;
vector[delay_type_max[delay_id] + 1] delay_rev_pmf = get_delay_rev_pmf(
vector[t] scaled;
vector[t] convolved = rep_vector(1e-5, t);

if (obs_scale) {
scaled = scale_obs(to_vector(primary[i]), frac_obs[i, 1]);
} else {
scaled = to_vector(primary[i]);
}

if (delay_id) {
vector[delay_type_max[delay_id] + 1] delay_rev_pmf = get_delay_rev_pmf(
delay_id, delay_type_max[delay_id] + 1, delay_types_p, delay_types_id,
delay_types_groups, delay_max, delay_np_pmf,
delay_np_pmf_groups, delay_mean[i], delay_sd[i], delay_dist,
0, 1, 0
);
);
convolved = convolved + convolve_to_report(scaled, delay_rev_pmf, 0);
} else {
convolved = convolved + scaled;
}

// calculate secondary reports from primary
secondary =
calculate_secondary(
to_vector(primary[i]), obs, frac_obs[i], delay_rev_pmf, cumulative,
historic, primary_hist_additive, current, primary_current_additive,
t - h + 1
);
secondary = calculate_secondary(
scaled, convolved, obs, cumulative, historic, primary_hist_additive,
current, primary_current_additive, t - h + 1
);

// weekly reporting effect
if (week_effect > 1) {
secondary = day_of_week_effect(secondary, day_of_week, to_vector(day_of_week_simplex[i]));
Expand Down
9 changes: 9 additions & 0 deletions tests/testthat/test-estimate_secondary.R
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,15 @@ test_that("estimate_secondary successfully returns estimates when accumulating t
expect_true(is.list(inc_weekly$data))
})

test_that("estimate_secondary works when only estimating scaling", {
inc <- estimate_secondary(inc_cases[1:60],
obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE),
delay = delay_opts(),
verbose = FALSE
)
expect_equal(names(inc), c("predictions", "posterior", "data", "fit"))
})

test_that("estimate_secondary can recover simulated parameters", {
expect_equal(
inc_posterior[, mean], c(1.8, 0.5, 0.4),
Expand Down
18 changes: 10 additions & 8 deletions tests/testthat/test-stan-secondary.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ skip_on_os("windows")
# test primary reports and observations
reports <- rep(10, 20)
obs <- rep(4, 20)
delay_pmf <- reverse_mf(discretised_pmf(log(3), 0.1, 5, 0))
delay_rev_pmf <- reverse_mf(discretised_pmf(log(3), 0.1, 5, 0))
scaled <- reports * 0.1
convolved <- rep(1e-5, 20) + convolve_to_report(scaled, delay_rev_pmf, 0)

check_equal <- function(args, target, dof = 0, dev = FALSE) {
out <- do.call(calculate_secondary, args)
Expand All @@ -17,46 +19,46 @@ check_equal <- function(args, target, dof = 0, dev = FALSE) {

test_that("calculate_secondary can calculate prevalence as expected", {
check_equal(
args = list(reports, obs, 0.1, delay_pmf, 1, 1, 1, 1, 1, 20),
args = list(scaled, convolved, obs, 1, 1, 1, 1, 1, 20),
target = c(1, 5, 5.5, rep(6, 17)), dof = 1
)
})

test_that("calculate_secondary can calculate incidence as expected", {
check_equal(
args = list(reports, obs, 0.1, delay_pmf, 0, 1, 1, 1, 1, 20),
args = list(scaled, convolved, obs, 0, 1, 1, 1, 1, 20),
target = c(1, 1, 1.5, rep(2.0, 17)), dof = 1
)
})

test_that("calculate_secondary can calculate incidence as expected", {
check_equal(
args = list(reports, obs, 0.1, delay_pmf, 0, 1, 1, 1, 1, 20),
args = list(scaled, convolved, obs, 0, 1, 1, 1, 1, 20),
target = c(1, 1, 1.5, rep(2.0, 17)), dof = 1
)
})

test_that("calculate_secondary can calculate incidence using only historic reports", {
check_equal(
args = list(reports, obs, 0.1, delay_pmf, 0, 1, 1, 0, 1, 20),
args = list(scaled, convolved, obs, 0, 1, 1, 0, 1, 20),
target = c(0, 0, rep(1, 18)), dof = 0
)
})

test_that("calculate_secondary can calculate incidence using only current reports", {
check_equal(
args = list(reports, obs, 0.1, delay_pmf, 0, 0, 1, 1, 1, 20),
args = list(scaled, convolved, obs, 0, 0, 1, 1, 1, 20),
target = rep(1, 20), dof = 0
)
})

test_that("calculate_secondary can switch into prediction mode as expected", {
check_equal(
args = list(reports, obs, 0.1, delay_pmf, 1, 0, 1, 1, 1, 20),
args = list(scaled, convolved, obs, 1, 0, 1, 1, 1, 20),
target = c(1, rep(5, 19)), dof = 0
)
check_equal(
args = list(reports, obs, 0.1, delay_pmf, 1, 0, 1, 1, 1, 10),
args = list(scaled, convolved, obs, 1, 0, 1, 1, 1, 10),
target = c(1, rep(5, 9), 6:15), dof = 0
)
})

0 comments on commit e1ff1f5

Please sign in to comment.