diff --git a/NEWS.md b/NEWS.md index 29b8a2ebf..a2df2e442 100644 --- a/NEWS.md +++ b/NEWS.md @@ -22,6 +22,7 @@ ## Package changes - The internal functions `create_clean_reported_cases()` has been broken up into several functions, with relevant ones `filter_leading_zeros()`, `add_breakpoints()` and `apply_zero_threshold()` exposed to the user. By @sbfnk in #884 and reviewed by @seabbs and @jamesmbaazam. +- The step of estimating early infections and growth in the internal function `create_stan_data()` has been separated into a new internal function `estimate_early_dynamics()`. By @jamesmbaazam in #888 and reviewed by @sbfnk. ## Documentation diff --git a/R/create.R b/R/create.R index ade71b2e4..789bbb444 100644 --- a/R/create.R +++ b/R/create.R @@ -444,6 +444,45 @@ create_obs_model <- function(obs = obs_opts(), dates) { return(data) } + +#' Calculate prior infections and fit early growth +#' +#' @description Calculates the prior infections and growth rate based on the +#' first week's data. +#' +#' @param cases Numeric vector; the case counts from the input data. +#' @inheritParams create_stan_data +#' @return A list containing `prior_infections` and `prior_growth`. +#' @keywords internal +estimate_early_dynamics <- function(cases, seeding_time) { + first_week <- data.table::data.table( + confirm = cases[seq_len(min(7, length(cases)))], + t = seq_len(min(7, length(cases))) + )[!is.na(confirm)] + + # Calculate prior infections + prior_infections <- log(mean(first_week$confirm, na.rm = TRUE)) + prior_infections <- ifelse( + is.na(prior_infections) || is.null(prior_infections), + 0, prior_infections + ) + + # Calculate prior growth + if (seeding_time > 1 && nrow(first_week) > 1) { + safe_lm <- purrr::safely(stats::lm) + prior_growth <- safe_lm(log(confirm) ~ t, data = first_week)[[1]] + prior_growth <- ifelse( + is.null(prior_growth), 0, prior_growth$coefficients[2] + ) + } else { + prior_growth <- 0 + } + return(list( + prior_infections = prior_infections, + prior_growth = prior_growth + )) +} + #' Create Stan Data Required for estimate_infections #' #' @description`r lifecycle::badge("stable")` @@ -501,28 +540,11 @@ create_stan_data <- function(data, seeding_time, delay = stan_data$seeding_time, horizon = stan_data$horizon ) ) - # initial estimate of growth - first_week <- data.table::data.table( - confirm = cases[seq_len(min(7, length(cases)))], - t = seq_len(min(7, length(cases))) - )[!is.na(confirm)] - stan_data$prior_infections <- log(mean(first_week$confirm, na.rm = TRUE)) - stan_data$prior_infections <- ifelse( - is.na(stan_data$prior_infections) || is.null(stan_data$prior_infections), - 0, stan_data$prior_infections + # calculate prior infections and fit early growth + stan_data <- c( + stan_data, + estimate_early_dynamics(cases, seeding_time) ) - if (stan_data$seeding_time > 1 && nrow(first_week) > 1) { - safe_lm <- purrr::safely(stats::lm) - stan_data$prior_growth <- safe_lm(log(confirm) ~ t, - data = first_week - )[[1]] - stan_data$prior_growth <- ifelse(is.null(stan_data$prior_growth), 0, - stan_data$prior_growth$coefficients[2] - ) - } else { - stan_data$prior_growth <- 0 - } - # backcalculation settings stan_data <- c(stan_data, create_backcalc_data(backcalc)) # gaussian process data diff --git a/man/estimate_early_dynamics.Rd b/man/estimate_early_dynamics.Rd new file mode 100644 index 000000000..399ff76b4 --- /dev/null +++ b/man/estimate_early_dynamics.Rd @@ -0,0 +1,22 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/create.R +\name{estimate_early_dynamics} +\alias{estimate_early_dynamics} +\title{Calculate prior infections and fit early growth} +\usage{ +estimate_early_dynamics(cases, seeding_time) +} +\arguments{ +\item{cases}{Numeric vector; the case counts from the input data.} + +\item{seeding_time}{Integer; seeding time, usually obtained using +\code{\link[=get_seeding_time]{get_seeding_time()}}.} +} +\value{ +A list containing \code{prior_infections} and \code{prior_growth}. +} +\description{ +Calculates the prior infections and growth rate based on the +first week's data. +} +\keyword{internal} diff --git a/tests/testthat/test-estimate-early-dynamics.R b/tests/testthat/test-estimate-early-dynamics.R new file mode 100644 index 000000000..b7c0d50bd --- /dev/null +++ b/tests/testthat/test-estimate-early-dynamics.R @@ -0,0 +1,49 @@ +test_that("estimate_early_dynamics works", { + cases <- EpiNow2::example_confirmed[1:30] + prior_estimates <- estimate_early_dynamics(cases$confirm, 7) + # Check dimensions + expect_identical( + names(prior_estimates), + c("prior_infections", "prior_growth") + ) + expect_identical(length(prior_estimates), 2L) + # Check values + expect_identical( + round(prior_estimates$prior_infections, 2), + 4.53 + ) + expect_identical( + round(prior_estimates$prior_growth, 2), + 0.35 + ) +}) + +test_that("estimate_early_dynamics handles NA values correctly", { + cases <- c(10, 20, NA, 40, 50, NA, 70) + prior_estimates <- estimate_early_dynamics(cases, 7) + expect_equal( + prior_estimates$prior_infections, + log(mean(c(10, 20, 40, 50, 70), na.rm = TRUE)) + ) + expect_true(!is.na(prior_estimates$prior_growth)) +}) + +test_that("estimate_early_dynamics handles exponential growth", { + cases <- 2^(c(0:6)) # Exponential growth + prior_estimates <- estimate_early_dynamics(cases, 7) + expect_equal(prior_estimates$prior_infections, log(mean(cases[1:7]))) + expect_true(prior_estimates$prior_growth > 0) # Growth should be positive +}) + +test_that("estimate_early_dynamics handles exponential decline", { + cases <- rev(2^(c(0:6))) # Exponential decline + prior_estimates <- estimate_early_dynamics(cases, 7) + expect_equal(prior_estimates$prior_infections, log(mean(cases[1:7]))) + expect_true(prior_estimates$prior_growth < 0) # Growth should be negative +}) + +test_that("estimate_early_dynamics correctly handles seeding time less than 2", { + cases <- c(5, 10, 20) # Less than 7 days of data + prior_estimates <- estimate_early_dynamics(cases, 1) + expect_equal(prior_estimates$prior_growth, 0) # Growth should be 0 if seeding time is <= 1 +}) \ No newline at end of file