From 50fc3cffaca0571b2852ae687436139e073a3a6c Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Thu, 15 Feb 2024 12:21:17 +0000 Subject: [PATCH] Add option to accumulate observations (#534) * add option to accumulate observations * accumulate in estimate_secondary model * add test for weekly accumulation * check there's data to fit initial growth model * ignore first observation when accumulating * document "na" argument * add news item * update obs_opts tests * make logical operator scalar * make NA option work with estimate_secondary * add tests * Apply suggestions from code review Co-authored-by: Sam Abbott --------- Co-authored-by: Sam Abbott --- NAMESPACE | 1 + NEWS.md | 1 + R/create.R | 35 ++++++++++---- R/estimate_secondary.R | 17 +++++-- R/opts.R | 53 ++++++++++++++-------- inst/stan/data/observation_model.stan | 1 + inst/stan/estimate_infections.stan | 4 +- inst/stan/estimate_secondary.stan | 8 +++- inst/stan/functions/observation_model.stan | 40 ++++++++++++---- man/create_clean_reported_cases.Rd | 2 +- man/create_complete_cases.Rd | 25 ++++++++++ man/obs_opts.Rd | 19 ++++++-- tests/testthat/test-create_obs_model.R | 4 +- tests/testthat/test-estimate_infections.R | 8 ++++ tests/testthat/test-estimate_secondary.R | 53 ++++++++++++++++++++++ 15 files changed, 222 insertions(+), 49 deletions(-) create mode 100644 man/create_complete_cases.Rd diff --git a/NAMESPACE b/NAMESPACE index 9041b04eb..22c7377b8 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -133,6 +133,7 @@ importFrom(data.table,fwrite) importFrom(data.table,getDTthreads) importFrom(data.table,melt) importFrom(data.table,merge.data.table) +importFrom(data.table,nafill) importFrom(data.table,rbindlist) importFrom(data.table,setDT) importFrom(data.table,setDTthreads) diff --git a/NEWS.md b/NEWS.md index c006184ca..e68a8f9e5 100644 --- a/NEWS.md +++ b/NEWS.md @@ -26,6 +26,7 @@ ## Model changes * Updated the parameterisation of the dispersion term `phi` to be `phi = 1 / sqrt_phi ^ 2` rather than the previous parameterisation `phi = 1 / sqrt(sqrt_phi)` based on the suggested prior [here](https://github.com/stan-dev/stan/wiki/Prior-Choice-Recommendations#story-when-the-generic-prior-fails-the-case-of-the-negative-binomial) and the performance benefits seen in the `epinowcast` package (see [here](https://github.com/epinowcast/epinowcast/blob/8eff560d1fd8305f5fb26c21324b2bfca1f002b4/inst/stan/epinowcast.stan#L314)). By @seabbs in # and reviewed by @sbfnk. +* Added an `na` argument to `obs_opts()` that allows the user to specify whether NA values in the data should be interpreted as missing or accumulated in the next non-NA data point. By @sbfnk in #534 and reviewed by @seabbs. # EpiNow2 1.4.0 diff --git a/R/create.R b/R/create.R index 7846bc600..f58ad80d6 100644 --- a/R/create.R +++ b/R/create.R @@ -26,7 +26,7 @@ #' @export #' @examples #' create_clean_reported_cases(example_confirmed, 7) -create_clean_reported_cases <- function(reported_cases, horizon, +create_clean_reported_cases <- function(reported_cases, horizon = 0, filter_leading_zeros = TRUE, zero_threshold = Inf, fill = NA_integer_) { @@ -75,6 +75,25 @@ create_clean_reported_cases <- function(reported_cases, horizon, return(reported_cases) } +#' Create complete cases +#' @description `r lifecycle::badge("stable")` +#' Creates a complete data set without NA values and appropriate indices +#' +#' @param cases; data frame with a column "confirm" that may contain NA values +#' @param burn_in; integer (default 0). 
Number of days to remove from the +#' start of the time series be filtered out. +#' +#' @return A data frame without NA values, with two columns: confirm (number) +#' @author Sebastian Funk +#' @importFrom data.table setDT +#' @keywords internal +create_complete_cases <- function(cases) { + cases <- setDT(cases) + cases[, lookup := seq_len(.N)] + cases <- cases[!is.na(cases$confirm)] + return(cases[]) +} + #' Create Delay Shifted Cases #' #' @description `r lifecycle::badge("stable")` @@ -397,6 +416,7 @@ create_obs_model <- function(obs = obs_opts(), dates) { week_effect = ifelse(obs$week_effect, obs$week_length, 1), obs_weight = obs$weight, obs_scale = as.numeric(length(obs$scale) != 0), + accumulate = obs$accumulate, likelihood = as.numeric(obs$likelihood), return_likelihood = as.numeric(obs$return_likelihood) ) @@ -447,16 +467,13 @@ create_stan_data <- function(reported_cases, seeding_time, backcalc, shifted_cases) { cases <- reported_cases[(seeding_time + 1):(.N - horizon)] - cases[, lookup := seq_len(.N)] - complete_cases <- cases[!is.na(cases$confirm)] - cases_time <- complete_cases$lookup - complete_cases <- complete_cases$confirm + complete_cases <- create_complete_cases(cases) cases <- cases$confirm data <- list( - cases = complete_cases, - cases_time = cases_time, - lt = length(cases_time), + cases = complete_cases$confirm, + cases_time = complete_cases$lookup, + lt = nrow(complete_cases), shifted_cases = shifted_cases, t = length(reported_cases$date), horizon = horizon, @@ -481,7 +498,7 @@ create_stan_data <- function(reported_cases, seeding_time, is.na(data$prior_infections) || is.null(data$prior_infections), 0, data$prior_infections ) - if (data$seeding_time > 1) { + if (data$seeding_time > 1 && nrow(first_week) > 1) { safe_lm <- purrr::safely(stats::lm) data$prior_growth <- safe_lm(log(confirm) ~ t, data = first_week)[[1]] data$prior_growth <- ifelse(is.null(data$prior_growth), 0, diff --git a/R/estimate_secondary.R b/R/estimate_secondary.R index f0511337b..b31b418c5 100644 --- a/R/estimate_secondary.R +++ b/R/estimate_secondary.R @@ -64,7 +64,7 @@ #' @inheritParams calc_CrIs #' @importFrom rstan sampling #' @importFrom lubridate wday -#' @importFrom data.table as.data.table merge.data.table +#' @importFrom data.table as.data.table merge.data.table nafill #' @importFrom utils modifyList #' @importFrom checkmate assert_class assert_numeric assert_data_frame #' assert_logical @@ -166,6 +166,15 @@ estimate_secondary <- function(reports, assert_logical(verbose) reports <- data.table::as.data.table(reports) + secondary_reports <- reports[, list(date, confirm = secondary)] + secondary_reports <- create_clean_reported_cases(secondary_reports) + ## fill in missing data (required if fitting to prevalence) + complete_secondary <- create_complete_cases(secondary_reports) + + ## fill down + secondary_reports[, confirm := nafill(confirm, type = "locf")] + ## fill any early data up + secondary_reports[, confirm := nafill(confirm, type = "nocb")] if (burn_in >= nrow(reports)) { stop("burn_in is greater or equal to the number of observations. 
@@ -174,8 +183,10 @@ estimate_secondary <- function(reports, # observation and control data data <- list( t = nrow(reports), - obs = reports$secondary, primary = reports$primary, + obs = secondary_reports$confirm, + obs_time = complete_secondary[lookup > burn_in]$lookup - burn_in, + lt = sum(complete_secondary$lookup > burn_in), burn_in = burn_in, seeding_time = 0 ) @@ -391,7 +402,7 @@ plot.estimate_secondary <- function(x, primary = FALSE, from = NULL, to = NULL, new_obs = NULL, ...) { - predictions <- data.table::copy(x$predictions) + predictions <- data.table::copy(x$predictions)[!is.na(secondary)] if (!is.null(new_obs)) { new_obs <- data.table::as.data.table(new_obs) diff --git a/R/opts.R b/R/opts.R index e5f58c5b3..2ad5f76a8 100644 --- a/R/opts.R +++ b/R/opts.R @@ -427,32 +427,35 @@ gp_opts <- function(basis_prop = 0.2, #' Defines a list specifying the structure of the observation #' model. Custom settings can be supplied which override the defaults. #' @param family Character string defining the observation model. Options are -#' Negative binomial ("negbin"), the default, and Poisson. -#' @param phi A numeric vector of length 2, defaults to 0, 1. Indicates the -#' mean and standard deviation of the normal prior used for the observation -#' process. -#' -#' @param weight Numeric, defaults to 1. Weight to give the observed data in -#' the log density. +#' Negative binomial ("negbin"), the default, and Poisson. +#' @param phi A numeric vector of length 2, defaults to 0, 1. Indicates the mean +#' and standard deviation of the normal prior used for the observation +#' process. +#' @param weight Numeric, defaults to 1. Weight to give the observed data in the +#' log density. #' @param week_effect Logical defaulting to `TRUE`. Should a day of the week -#' effect be used in the observation model. -#' +#' effect be used in the observation model. #' @param week_length Numeric assumed length of the week in days, defaulting to #' 7 days. This can be modified if data aggregated over a period other than a #' week or if data has a non-weekly periodicity. -#' #' @param scale List, defaulting to an empty list. Should an scaling factor be -#' applied to map latent infections (convolved to date of report). If none -#' empty a mean (`mean`) and standard deviation (`sd`) needs to be supplied -#' defining the normally distributed scaling factor. -#' +#' applied to map latent infections (convolved to date of report). If none +#' empty a mean (`mean`) and standard deviation (`sd`) needs to be supplied +#' defining the normally distributed scaling factor. +#' @param na Character. Options are "missing" (the default) and "accumulate". +#' This determines how NA values in the data are interpreted. If set to +#' "missing", any NA values in the observation data set will be interpreted as +#' missing and skipped in the likelihood. If set to "accumulate", modelled +#' observations will be accumulated and added to the next non-NA data point. +#' This can be used to model incidence data that is reported at less than +#' daily intervals. If set to "accumulate", the first data point is not +#' included in the likelihood but used only to reset modelled observations to +#' zero. #' @param likelihood Logical, defaults to `TRUE`. Should the likelihood be -#' included in the model. -#' +#' included in the model. #' @param return_likelihood Logical, defaults to `FALSE`. Should the likelihood -#' be returned by the model. +#' be returned by the model. 
#' @importFrom rlang arg_match -#' #' @return An `` object of observation model settings. #' @author Sam Abbott #' @export @@ -471,11 +474,24 @@ obs_opts <- function(family = "negbin", week_effect = TRUE, week_length = 7, scale = list(), + na = c("missing", "accumulate"), likelihood = TRUE, return_likelihood = FALSE) { if (length(phi) != 2 || !is.numeric(phi)) { stop("phi be numeric and of length two") } + na <- arg_match(na) + if (na == "accumulate") { + message( + "Accumulating modelled values that correspond to NA values in the data ", + "by adding them to the next non-NA data point. This means that the ", + "first data point is not included in the likelihood but used only to ", + "reset modelled observations to zero. If the first data point should be ", + "included in the likelihood this can be achieved by adding a data point ", + "of arbitrary value before the first data point." + ) + } + obs <- list( family = arg_match(family, values = c("poisson", "negbin")), phi = phi, @@ -483,6 +499,7 @@ obs_opts <- function(family = "negbin", week_effect = week_effect, week_length = week_length, scale = scale, + accumulate = as.integer(na == "accumulate"), likelihood = likelihood, return_likelihood = return_likelihood ) diff --git a/inst/stan/data/observation_model.stan b/inst/stan/data/observation_model.stan index 0ce9ef3bb..671004ef4 100644 --- a/inst/stan/data/observation_model.stan +++ b/inst/stan/data/observation_model.stan @@ -9,5 +9,6 @@ real obs_weight; // weight given to observation in log density int likelihood; // Should the likelihood be included in the model int return_likelihood; // Should the likehood be returned by the model + int accumulate; // Should missing values be accumulated int trunc_id; // id of truncation int delay_id; // id of delay diff --git a/inst/stan/estimate_infections.stan b/inst/stan/estimate_infections.stan index c1ac8c63e..7eb128b2d 100644 --- a/inst/stan/estimate_infections.stan +++ b/inst/stan/estimate_infections.stan @@ -148,8 +148,8 @@ model { // observed reports from mean of reports (update likelihood) if (likelihood) { report_lp( - cases, obs_reports[cases_time], rep_phi, phi_mean, phi_sd, model_type, - obs_weight + cases, cases_time, obs_reports, rep_phi, phi_mean, phi_sd, model_type, + obs_weight, accumulate ); } } diff --git a/inst/stan/estimate_secondary.stan b/inst/stan/estimate_secondary.stan index e2209349b..77507a9ed 100644 --- a/inst/stan/estimate_secondary.stan +++ b/inst/stan/estimate_secondary.stan @@ -8,7 +8,9 @@ functions { data { int t; // time of observations + int lt; // time of observations array[t] int obs; // observed secondary data + array[lt] int obs_time; // observed secondary data vector[t] primary; // observed primary data int burn_in; // time period to not use for fitting #include data/secondary.stan @@ -83,8 +85,10 @@ model { } // observed secondary reports from mean of secondary reports (update likelihood) if (likelihood) { - report_lp(obs[(burn_in + 1):t], secondary[(burn_in + 1):t], - rep_phi, phi_mean, phi_sd, model_type, 1); + report_lp( + obs[(burn_in + 1):t][obs_time], obs_time, secondary[(burn_in + 1):t], + rep_phi, phi_mean, phi_sd, model_type, 1, accumulate + ); } } diff --git a/inst/stan/functions/observation_model.stan b/inst/stan/functions/observation_model.stan index ed3364ae6..6535c07f7 100644 --- a/inst/stan/functions/observation_model.stan +++ b/inst/stan/functions/observation_model.stan @@ -51,22 +51,46 @@ void truncation_lp(array[] real truncation_mean, array[] real truncation_sd, } } // update log density 
for reported cases -void report_lp(array[] int cases, vector reports, +void report_lp(array[] int cases, array[] int cases_time, vector reports, array[] real rep_phi, real phi_mean, real phi_sd, - int model_type, real weight) { + int model_type, real weight, int accumulate) { + int n = num_elements(cases_time) - accumulate; // number of observations + vector[n] obs_reports; // reports at observation time + array[n] int obs_cases; // observed cases at observation time + if (accumulate) { + int t = num_elements(reports); + int i = 0; + int current_obs = 0; + obs_reports = rep_vector(0, n); + while (i <= t && current_obs <= n) { + if (current_obs > 0) { // first observation gets ignored when accumulating + obs_reports[current_obs] += reports[i]; + } + if (i == cases_time[current_obs + 1]) { + current_obs += 1; + } + i += 1; + } + obs_cases = cases[2:(n + 1)]; + } else { + obs_reports = reports[cases_time]; + obs_cases = cases; + } if (model_type) { - real dispersion = 1 / pow(rep_phi[model_type], 2); + real dispersion = 1 / pow(rep_phi[model_type], 2); rep_phi[model_type] ~ normal(phi_mean, phi_sd) T[0,]; if (weight == 1) { - cases ~ neg_binomial_2(reports, dispersion); + obs_cases ~ neg_binomial_2(obs_reports, dispersion); } else { - target += neg_binomial_2_lpmf(cases | reports, dispersion) * weight; + target += neg_binomial_2_lpmf( + obs_cases | obs_reports, dispersion + ) * weight; } } else { if (weight == 1) { - cases ~ poisson(reports); + obs_cases ~ poisson(obs_reports); } else { - target += poisson_lpmf(cases | reports) * weight; + target += poisson_lpmf(obs_cases | obs_reports) * weight; } } } @@ -97,7 +121,7 @@ array[] int report_rng(vector reports, array[] real rep_phi, int model_type) { if (model_type) { dispersion = 1 / pow(rep_phi[model_type], 2); } - + for (s in 1:t) { if (reports[s] < 1e-8) { sampled_reports[s] = 0; diff --git a/man/create_clean_reported_cases.Rd b/man/create_clean_reported_cases.Rd index c53830c0c..5daf45877 100644 --- a/man/create_clean_reported_cases.Rd +++ b/man/create_clean_reported_cases.Rd @@ -6,7 +6,7 @@ \usage{ create_clean_reported_cases( reported_cases, - horizon, + horizon = 0, filter_leading_zeros = TRUE, zero_threshold = Inf, fill = NA_integer_ diff --git a/man/create_complete_cases.Rd b/man/create_complete_cases.Rd new file mode 100644 index 000000000..79eb108c6 --- /dev/null +++ b/man/create_complete_cases.Rd @@ -0,0 +1,25 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/create.R +\name{create_complete_cases} +\alias{create_complete_cases} +\title{Create complete cases} +\usage{ +create_complete_cases(cases) +} +\arguments{ +\item{cases;}{data frame with a column "confirm" that may contain NA values} + +\item{burn_in;}{integer (default 0). 
Number of days to remove from the +start of the time series be filtered out.} +} +\value{ +A data frame without NA values, with two columns: confirm (number) +} +\description{ +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#stable}{\figure{lifecycle-stable.svg}{options: alt='[Stable]'}}}{\strong{[Stable]}} +Creates a complete data set without NA values and appropriate indices +} +\author{ +Sebastian Funk +} +\keyword{internal} diff --git a/man/obs_opts.Rd b/man/obs_opts.Rd index f9c617dd5..85aecaa7b 100644 --- a/man/obs_opts.Rd +++ b/man/obs_opts.Rd @@ -11,6 +11,7 @@ obs_opts( week_effect = TRUE, week_length = 7, scale = list(), + na = c("missing", "accumulate"), likelihood = TRUE, return_likelihood = FALSE ) @@ -19,12 +20,12 @@ obs_opts( \item{family}{Character string defining the observation model. Options are Negative binomial ("negbin"), the default, and Poisson.} -\item{phi}{A numeric vector of length 2, defaults to 0, 1. Indicates the -mean and standard deviation of the normal prior used for the observation +\item{phi}{A numeric vector of length 2, defaults to 0, 1. Indicates the mean +and standard deviation of the normal prior used for the observation process.} -\item{weight}{Numeric, defaults to 1. Weight to give the observed data in -the log density.} +\item{weight}{Numeric, defaults to 1. Weight to give the observed data in the +log density.} \item{week_effect}{Logical defaulting to \code{TRUE}. Should a day of the week effect be used in the observation model.} @@ -38,6 +39,16 @@ applied to map latent infections (convolved to date of report). If none empty a mean (\code{mean}) and standard deviation (\code{sd}) needs to be supplied defining the normally distributed scaling factor.} +\item{na}{Character. Options are "missing" (the default) and "accumulate". +This determines how NA values in the data are interpreted. If set to +"missing", any NA values in the observation data set will be interpreted as +missing and skipped in the likelihood. If set to "accumulate", modelled +observations will be accumulated and added to the next non-NA data point. +This can be used to model incidence data that is reported at less than +daily intervals. If set to "accumulate", the first data point is not +included in the likelihood but used only to reset modelled observations to +zero.} + \item{likelihood}{Logical, defaults to \code{TRUE}. 
Should the likelihood be included in the model.} diff --git a/tests/testthat/test-create_obs_model.R b/tests/testthat/test-create_obs_model.R index 4829bb0af..4211c7dbe 100644 --- a/tests/testthat/test-create_obs_model.R +++ b/tests/testthat/test-create_obs_model.R @@ -3,10 +3,10 @@ dates <- seq(as.Date("2020-03-15"), by = "days", length.out = 15) test_that("create_obs_model works with default settings", { obs <- create_obs_model(dates = dates) - expect_equal(length(obs), 11) + expect_equal(length(obs), 12) expect_equal(names(obs), c( "model_type", "phi_mean", "phi_sd", "week_effect", "obs_weight", - "obs_scale", "likelihood", "return_likelihood", + "obs_scale", "accumulate", "likelihood", "return_likelihood", "day_of_week", "obs_scale_mean", "obs_scale_sd" )) diff --git a/tests/testthat/test-estimate_infections.R b/tests/testthat/test-estimate_infections.R index 6a90bb98f..53234b0ff 100644 --- a/tests/testthat/test-estimate_infections.R +++ b/tests/testthat/test-estimate_infections.R @@ -43,6 +43,14 @@ test_that("estimate_infections successfully returns estimates when passed NA val test_estimate_infections(reported_cases_na) }) +test_that("estimate_infections successfully returns estimates when accumulating to weekly", { + skip_on_cran() + reported_cases_weekly <- data.table::copy(reported_cases) + reported_cases_weekly[, confirm := frollsum(confirm, 7)] + reported_cases_weekly <- + reported_cases_weekly[seq(7, nrow(reported_cases_weekly), 7)] + test_estimate_infections(reported_cases_weekly, obs = obs_opts(na = "accumulate")) +}) test_that("estimate_infections successfully returns estimates using no delays", { skip_on_cran() diff --git a/tests/testthat/test-estimate_secondary.R b/tests/testthat/test-estimate_secondary.R index 5485f309b..f252bc0f0 100644 --- a/tests/testthat/test-estimate_secondary.R +++ b/tests/testthat/test-estimate_secondary.R @@ -76,6 +76,59 @@ test_that("estimate_secondary can return values from simulated data and plot expect_error(plot(inc, primary = TRUE), NA) }) +test_that("estimate_secondary successfully returns estimates when passed NA values", { + skip_on_cran() + cases_na <- data.table::copy(inc_cases) + cases_na[sample(1:60, 5), secondary := NA] + inc_na <- estimate_secondary(cases_na[1:60], + delays = delay_opts( + dist_spec( + mean = 1.8, mean_sd = 0, + sd = 0.5, sd_sd = 0, max = 30 + ) + ), + obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE), + verbose = FALSE + ) + prev_cases_na <- data.table::copy(prev_cases) + prev_cases_na[sample(1:60, 5), secondary := NA] + prev_na <- estimate_secondary(prev_cases_na[1:60], + secondary = secondary_opts(type = "prevalence"), + delays = delay_opts( + dist_spec( + mean = 1.8, mean_sd = 0, + sd = 0.5, sd_sd = 0, max = 30 + ) + ), + obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE), + verbose = FALSE + ) + expect_true(is.list(inc_na$data)) + expect_true(is.list(prev_na$data)) +}) + +test_that("estimate_secondary successfully returns estimates when accumulating to weekly", { + skip_on_cran() + secondary_weekly <- inc_cases[, list(date, secondary)] + secondary_weekly[, secondary := frollsum(secondary, 7)] + secondary_weekly <- secondary_weekly[seq(7, nrow(secondary_weekly), by = 7)] + cases_weekly <- merge( + cases[, list(date, primary)], secondary_weekly, by = "date", all.x = TRUE + ) + inc_weekly <- estimate_secondary(cases_weekly, + delays = delay_opts( + dist_spec( + mean = 1.8, mean_sd = 0, + sd = 0.5, sd_sd = 0, max = 30 + ) + ), + obs = obs_opts( + scale = list(mean = 0.4, 
+      sd = 0.05), week_effect = FALSE, na = "accumulate"
+    ), verbose = FALSE
+  )
+  expect_true(is.list(inc_weekly$data))
+})
+
 test_that("estimate_secondary can recover simulated parameters", {
   expect_equal(
     inc_posterior[, mean], c(1.8, 0.5, 0.4),