Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimise convolutions #745

Merged
merged 10 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
- The interface for defining delay distributions has been generalised to also cater for continuous distributions
- When defining probability distributions these can now be truncated using the `tolerance` argument
- Ornstein-Uhlenbeck and 5 / 2 Matérn kernels have been added. By @sbfnk in # and reviewed by @.
- Optimised convolution code to take into account the relative length of the vectors being convolved. See #745 by @seabbs and reviewed by @jamesmbaazam.
- Switch to broadcasting the day of the week effect. By @seabbs in #746 and reviewed by @jamesmbaazam.
- A warning is now thrown if nonparametric PMFs passed to delay options have consecutive tail values that are below a certain low threshold as these lead to loss in speed with little gain in accuracy. By @jamesmbaazam in #752 and reviewed by @seabbs.

Expand Down
114 changes: 89 additions & 25 deletions inst/stan/functions/convolve.stan
Original file line number Diff line number Diff line change
@@ -1,36 +1,100 @@
// convolve two vectors as a backwards dot product
// y vector should be reversed
// limited to the length of x and backwards looking for x indexes
/**
 * Calculate convolution indices for the case where s <= xlen
 *
 * The returned bounds are used by convolve_with_rev_pmf to slice the two
 * vectors as x[start_x:end_x] and y[start_y:end_y] before taking their dot
 * product.
 *
 * @param s Current position in the output vector
 * @param xlen Length of the x vector (not read in this branch; kept so the
 * signature matches calc_conv_indices_len)
 * @param ylen Length of the y vector
 * @return An array of integers: {start_x, end_x, start_y, end_y}
 */
array[] int calc_conv_indices_xlen(int s, int xlen, int ylen) {
  int s_minus_ylen = s - ylen;
  // x window ends at s and reaches back at most ylen elements
  int start_x = max(1, s_minus_ylen + 1);
  int end_x = s;
  // y window always ends at ylen; its start is clamped while s < ylen
  int start_y = max(1, 1 - s_minus_ylen);
  int end_y = ylen;
  return {start_x, end_x, start_y, end_y};
}

/**
 * Calculate convolution indices for the case where s > xlen
 *
 * The returned bounds are used by convolve_with_rev_pmf to slice the two
 * vectors as x[start_x:end_x] and y[start_y:end_y] before taking their dot
 * product. When s exceeds xlen + ylen - 1 the bounds describe empty ranges
 * (start greater than end).
 *
 * @param s Current position in the output vector
 * @param xlen Length of the x vector
 * @param ylen Length of the y vector
 * @return An array of integers: {start_x, end_x, start_y, end_y}
 */
array[] int calc_conv_indices_len(int s, int xlen, int ylen) {
  int s_minus_ylen = s - ylen;
  // x window is capped at the end of x once s has moved past it
  int start_x = max(1, s_minus_ylen + 1);
  int end_x = xlen;
  // y window shrinks from the right as s moves past xlen
  int start_y = max(1, 1 - s_minus_ylen);
  int end_y = ylen + xlen - s;
  return {start_x, end_x, start_y, end_y};
}

/**
 * Convolve a vector with a reversed probability mass function.
 *
 * This function performs a discrete convolution of two vectors, where the second vector
 * is assumed to be an already reversed probability mass function.
 *
 * The output is filled in two passes so that each pass can use the simpler
 * index calculation for its range: positions 1..xlen use
 * calc_conv_indices_xlen and positions (xlen + 1)..len use
 * calc_conv_indices_len.
 *
 * @param x The input vector to be convolved.
 * @param y The already reversed probability mass function vector.
 * @param len The desired length of the output vector.
 * @return A vector of length `len` containing the convolution result.
 * @throws If `len` is greater than `num_elements(x) + num_elements(y) - 1`
 * (the length of the full convolution) or if `len` is less than
 * `num_elements(x)`.
 */
vector convolve_with_rev_pmf(vector x, vector y, int len) {
  int xlen = num_elements(x);
  int ylen = num_elements(y);
  vector[len] z;

  if (xlen + ylen - 1 < len) {
    reject("convolve_with_rev_pmf: len is longer than x and y convolved");
  }

  if (xlen > len) {
    reject("convolve_with_rev_pmf: len is shorter than x");
  }

  // Output positions that fall within the length of x
  for (s in 1:xlen) {
    array[4] int indices = calc_conv_indices_xlen(s, xlen, ylen);
    z[s] = dot_product(x[indices[1]:indices[2]], y[indices[3]:indices[4]]);
  }

  // Remaining output positions beyond the end of x
  if (len > xlen) {
    for (s in (xlen + 1):len) {
      array[4] int indices = calc_conv_indices_len(s, xlen, ylen);
      z[s] = dot_product(x[indices[1]:indices[2]], y[indices[3]:indices[4]]);
    }
  }

  return z;
}


/**
 * Convolve infections to reported cases.
 *
 * This function convolves a vector of infections with a reversed delay
 * distribution to produce a vector of reported (but still unobserved) cases.
 * If the delay distribution is empty the infections are passed through
 * without convolution.
 *
 * @param infections A vector of infection counts.
 * @param delay_rev_pmf A vector representing the reversed probability mass
 * function of the delay distribution.
 * @param seeding_time The number of initial time steps to exclude from the
 * output.
 * @return A vector of reported cases, starting from `seeding_time + 1`.
 */
vector convolve_to_report(vector infections,
                          vector delay_rev_pmf,
                          int seeding_time) {
  int t = num_elements(infections);
  int delays = num_elements(delay_rev_pmf);

  // No delay: reports are just the infections after the seeding period
  if (delays == 0) {
    return infections[(seeding_time + 1):t];
  }

  vector[t] unobs_reports = convolve_with_rev_pmf(infections, delay_rev_pmf, t);
  return unobs_reports[(seeding_time + 1):t];
}
6 changes: 3 additions & 3 deletions inst/stan/functions/delays.stan
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ vector get_delay_rev_pmf(
pmf[1:new_len] = new_variable_pmf;
} else { // subsequent delay to be convolved
pmf[1:new_len] = convolve_with_rev_pmf(
pmf[1:current_len], reverse_mf(new_variable_pmf), new_len
pmf[1:current_len], reverse(new_variable_pmf), new_len
);
}
} else { // nonparametric
Expand All @@ -54,7 +54,7 @@ vector get_delay_rev_pmf(
pmf[1:new_len] = delay_np_pmf[start:end];
} else { // subsequent delay to be convolved
pmf[1:new_len] = convolve_with_rev_pmf(
pmf[1:current_len], reverse_mf(delay_np_pmf[start:end]), new_len
pmf[1:current_len], reverse(delay_np_pmf[start:end]), new_len
);
}
}
Expand All @@ -70,7 +70,7 @@ vector get_delay_rev_pmf(
pmf = cumulative_sum(pmf);
}
if (reverse_pmf) {
pmf = reverse_mf(pmf);
pmf = reverse(pmf);
}
return pmf;
}
Expand Down
33 changes: 0 additions & 33 deletions inst/stan/functions/pmfs.stan
Original file line number Diff line number Diff line change
Expand Up @@ -30,36 +30,3 @@ vector discretised_pmf(vector params, int n, int dist) {
}
return(exp(lpmf));
}

/**
 * Reverse a mass function.
 *
 * Thin wrapper around the built-in `reverse()`, kept so existing callers of
 * `reverse_mf` continue to work.
 *
 * @param pmf Vector to reverse.
 * @return A vector of the same length as `pmf` with its elements in reverse
 * order.
 */
vector reverse_mf(vector pmf) {
  return reverse(pmf);
}

/**
 * Build a descending integer sequence as a vector.
 *
 * @param base Value of the last element of the sequence.
 * @param len Number of elements in the sequence.
 * @return A vector of length `len` running from `len + base - 1` down to
 * `base` in steps of one.
 */
vector rev_seq(int base, int len) {
  vector[len] descending;
  int top = len + base;
  for (pos in 1:len) {
    descending[pos] = top - pos;
  }
  return descending;
}

/**
 * Mean of a reversed probability mass function.
 *
 * @param rev_pmf Reversed probability mass function.
 * @param base Value passed to rev_seq as the last (lowest) support value.
 * @return The dot product of `rev_pmf` with the descending sequence
 * produced by `rev_seq(base, num_elements(rev_pmf))`.
 */
real rev_pmf_mean(vector rev_pmf, int base) {
  int n = num_elements(rev_pmf);
  vector[n] support = rev_seq(base, n);
  return dot_product(support, rev_pmf);
}

/**
 * Variance of a reversed probability mass function.
 *
 * @param rev_pmf Reversed probability mass function.
 * @param base Value passed to rev_seq as the last (lowest) support value.
 * @param mean Mean to subtract (squared) from the second moment.
 * @return The dot product of `rev_pmf` with the squared descending sequence
 * from `rev_seq`, minus `mean` squared.
 */
real rev_pmf_var(vector rev_pmf, int base, real mean) {
  int n = num_elements(rev_pmf);
  vector[n] sq_support = rev_seq(base, n);
  // square the support values in place
  for (k in 1:n) {
    sq_support[k] = pow(sq_support[k], 2);
  }
  return dot_product(sq_support, rev_pmf) - pow(mean, 2);
}
29 changes: 26 additions & 3 deletions tests/testthat/test-stan-convole.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,21 @@
# Skip the tests in this file when running on CRAN or on Windows
skip_on_cran()
skip_on_os("windows")

test_that("convolve can combine two pmfs as expected", {
# Index calculation for output positions within the length of x (s <= xlen)
test_that("calc_conv_indices_xlen calculates correct indices", {
  expect_equal(
    calc_conv_indices_xlen(1, 5, 3),
    expected = c(1, 1, 3, 3)
  )
  expect_equal(
    calc_conv_indices_xlen(3, 5, 3),
    expected = c(1, 3, 1, 3)
  )
  expect_equal(
    calc_conv_indices_xlen(5, 5, 3),
    expected = c(3, 5, 1, 3)
  )
})

# Index calculation for output positions beyond the length of x (s > xlen)
test_that("calc_conv_indices_len calculates correct indices", {
  expect_equal(
    calc_conv_indices_len(6, 5, 3),
    expected = c(4, 5, 1, 2)
  )
  expect_equal(
    calc_conv_indices_len(7, 5, 3),
    expected = c(5, 5, 1, 1)
  )
  # s past the full convolution length yields empty ranges (start > end)
  expect_equal(
    calc_conv_indices_len(8, 5, 3),
    expected = c(6, 5, 1, 0)
  )
})

test_that("convolve_with_rev_pmf can combine two pmfs as expected", {
expect_equal(
convolve_with_rev_pmf(c(0.1, 0.2, 0.7), rev(c(0.1, 0.2, 0.7)), 5),
c(0.01, 0.04, 0.18, 0.28, 0.49),
Expand All @@ -14,7 +28,7 @@ test_that("convolve can combine two pmfs as expected", {
)
})

test_that("convolve performs the same as a numerical convolution", {
test_that("convolve_with_rev_pmf performs the same as a numerical convolution", {
# Sample and analytical PMFs for two Poisson distributions
x <- rpois(10000, 3)
xpmf <- dpois(0:20, 3)
Expand All @@ -32,7 +46,7 @@ test_that("convolve performs the same as a numerical convolution", {
expect_lte(sum(abs(conv_cdf - cdf)), 0.1)
})

test_that("convolve_dot_product can combine vectors as we expect", {
test_that("convolve_with_rev_pmf can combine vectors as we expect", {
expect_equal(
convolve_with_rev_pmf(c(0.1, 0.2, 0.7), rev(c(0.1, 0.2, 0.7)), 3),
c(0.01, 0.04, 0.18),
Expand All @@ -54,3 +68,12 @@ test_that("convolve_dot_product can combine vectors as we expect", {
x
)
})

# x longer than y with the output truncated to the length of x
test_that("convolve_with_rev_pmf can combine two vectors where x > y and len = x", {
  x <- c(1, 2, 3, 4, 5)
  y <- c(1, 2, 3)
  expect_equal(
    convolve_with_rev_pmf(x, rev(y), 5),
    c(1, 4, 10, 16, 22)
  )
})
2 changes: 1 addition & 1 deletion tests/testthat/test-stan-secondary.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ skip_on_os("windows")
# test primary reports and observations
reports <- rep(10, 20)
obs <- rep(4, 20)
delay_rev_pmf <- reverse_mf(discretised_pmf(c(log(3), 0.1), 5, 0))
delay_rev_pmf <- rev(discretised_pmf(c(log(3), 0.1), 5, 0))
scaled <- reports * 0.1
convolved <- rep(1e-5, 20) + convolve_to_report(scaled, delay_rev_pmf, 0)

Expand Down
Loading