Skip to content

Commit

Permalink
make laplace work for MVN test case
Browse files Browse the repository at this point in the history
  • Loading branch information
goldingn committed Feb 26, 2020
1 parent 2bae9ab commit ca8920b
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 11 deletions.
14 changes: 5 additions & 9 deletions R/marginalisers.R
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ laplace_approximation <- function(tolerance = 1e-6,

a <- out[[2]]

# apparently we need to redefine z and u here, or the backprop errors
# apparently we need to redefine z etc. here, or the backprop errors

# lots of duplicated code; this could be tidied up, but I ran out of time!
z <- tf$matmul(sigma, a) + mu
Expand All @@ -353,18 +353,14 @@ laplace_approximation <- function(tolerance = 1e-6,
d2 <- deriv[[2]]
w <- -d2
rw <- sqrt(w)
hessian <- tf$linalg$diag(tf$squeeze(w, 2L))

# approximate posterior covariance
# do we need the eye?
mat1 <- tf$matmul(rw, tf_transpose(rw)) * sigma + eye
l <- tf$cholesky(mat1)
v <- tf$linalg$triangular_solve(matrix = l,
rhs = sigma * rw,
lower = TRUE,
adjoint = TRUE)
covar <- sigma - tf$linalg$matmul(v, v, transpose_b = TRUE)
covar <- tf$linalg$inv(tf$linalg$inv(sigma) + hessian)

# log-determinant of l
mat1 <- tf$matmul(rw, tf_transpose(rw)) * sigma + eye
l <- tf$cholesky(mat1)
l_diag <- tf$matrix_diag_part(l)
logdet <- tf_sum(tf$log(l_diag))

Expand Down
60 changes: 58 additions & 2 deletions tests/testthat/test_marginalisation.R
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ test_that("inference runs with laplace approximation", {

})

test_that("laplace approximation converges on correct posterior", {
test_that("laplace approximation has correct posterior for univariate normal", {

skip_if_not(check_tf_version())
source("helpers.R")
Expand All @@ -224,7 +224,6 @@ test_that("laplace approximation converges on correct posterior", {
# analytic solution:

# nolint start
# Marginalising int we can write:
# Bayes theorum gives:
# p(theta | y) \propto p(y|theta) p(theta)
# which with normal densities is:
Expand Down Expand Up @@ -262,5 +261,62 @@ test_that("laplace approximation converges on correct posterior", {
# compare these to within a tolerance
compare_op(analytic, laplace)

})

test_that("laplace approximation has correct posterior for multivariate normal", {

skip_if_not(check_tf_version())
source("helpers.R")

# nolint start
# test vs analytic posterior on 8 schools data with no pooling:
# y_i ~ N(theta_i, obs_sd_i ^ 2)
# theta ~ MVN(mu, sigma)
# the posterior for theta is multivariate normal, so laplace should be exact
# nolint end

# eight schools data
y <- c(28.39, 7.94, -2.75 , 6.82, -0.64, 0.63, 18.01, 12.16)
obs_sd <- c(14.9, 10.2, 16.3, 11.0, 9.4, 11.4, 10.4, 17.6)

# prior parameters for mu and sigma
mu <- rnorm(8)
sigma <- rwish(1, 9, diag(8))[1, , ]

# analytic solution:

# nolint start
# Bayes theorum gives:
# p(theta | y) \propto p(y|theta) p(theta)
# which with normal densities is:
# p(theta_i | y_i) \propto N(y_i | theta_i, obs_sd_i ^ 2) *
# N(theta_i | mu, sd ^ 2)
# which is equivalent to:
# p(theta | y) \propto MNN(theta_mu, theta_sigma)
# theta_sigma = (sigma^-1 + diag(1 /obs_sd^2))^-1
# theta_mu = theta_sigma (sigma^-1 mu + diag(1 /obs_sd^2) mean(y))^-1
# conjugate prior, see Wikipedia conjugate prior table
# nolint end

i_obs_sigma <- diag(1 / (obs_sd ^ 2))
i_sigma <- solve(sigma)
theta_sigma <- solve(i_sigma + i_obs_sigma)
theta_mu <- theta_sigma %*% (i_sigma %*% mu + i_obs_sigma %*% y)
theta_sigma_flat <- theta_sigma[upper.tri(theta_sigma, diag = TRUE)]

lik <- function(theta) {
distribution(y) <- normal(t(theta), obs_sd)
}
out <- marginalise(lik,
multivariate_normal(t(mu), sigma),
laplace_approximation(diagonal_hessian = TRUE, tolerance = 1e-32))
res <- calculate(mean = t(out$mean), sigma = out$sigma, iterations = out$iterations)
theta_mu_est <- res$mean
theta_sigma_est <- res$sigma
theta_sigma_est_flat <- theta_sigma_est[upper.tri(theta_sigma_est, diag = TRUE)]

# compare these to within a tolerance
compare_op(theta_mu, theta_mu_est)
compare_op(theta_sigma_flat, theta_sigma_est_flat)

})

0 comments on commit ca8920b

Please sign in to comment.