diff --git a/R/marginalisers.R b/R/marginalisers.R index 40d4a161..c5b3f3a9 100644 --- a/R/marginalisers.R +++ b/R/marginalisers.R @@ -354,9 +354,19 @@ laplace_approximation <- function(tolerance = 1e-6, w <- -d2 rw <- sqrt(w) - # approximate posterior covariance & cholesky factor + # approximate posterior covariance + # do we need the eye? mat1 <- tf$matmul(rw, tf_transpose(rw)) * sigma + eye - u <- tf$cholesky(mat1) + l <- tf$cholesky(mat1) + v <- tf$linalg$triangular_solve(matrix = l, + rhs = sigma * rw, + lower = TRUE, + adjoint = TRUE) + covar <- sigma - tf$linalg$matmul(v, v, transpose_b = TRUE) + + # log-determinant of l + l_diag <- tf$matrix_diag_part(l) + logdet <- tf_sum(tf$log(l_diag)) # convergence information iter <- out[[7]] @@ -366,7 +376,8 @@ laplace_approximation <- function(tolerance = 1e-6, list(z = z, mu = mu, a = a, - u = u, + logdet = logdet, + covar = covar, iterations = iter, converged = converged) @@ -418,11 +429,16 @@ laplace_approximation <- function(tolerance = 1e-6, tf_operation = "get_element", operation_args = list("mu")) - u <- op("chol_sigma", - parameter_list, - dim = dim(sigma), - tf_operation = "get_element", - operation_args = list("u")) + logdet <- op("log determinant", + parameter_list, + tf_operation = "get_element", + operation_args = list("logdet")) + + covar <- op("covar", + parameter_list, + dim = dim(sigma), + tf_operation = "get_element", + operation_args = list("covar")) iterations <- op("iterations", parameter_list, @@ -438,7 +454,8 @@ laplace_approximation <- function(tolerance = 1e-6, list(z = z, a = a, mu = mu, - u = u, + logdet = logdet, + covar = covar, iterations = iterations, converged = converged) @@ -467,14 +484,12 @@ laplace_approximation <- function(tolerance = 1e-6, } mu <- parameters$mu - u <- parameters$u + logdet <- parameters$logdet z <- parameters$z a <- parameters$a # the approximate marginal conditional posterior - u_diag <- tf$matrix_diag_part(u) - logdet <- tf_sum(tf$log(u_diag)) - nmcp <- psi(a, z, mu) + tf$squeeze(logdet, 1) + nmcp <- psi(a, z, mu) + tf$squeeze(u_logdet, 1) -nmcp @@ -483,7 +498,7 @@ laplace_approximation <- function(tolerance = 1e-6, return_list_function <- function(parameters) { list(mean = t(parameters$z), - sigma = chol2symm(parameters$u), + sigma = parameters$covar, iterations = parameters$iterations, converged = parameters$converged) diff --git a/tests/testthat/test_marginalisation.R b/tests/testthat/test_marginalisation.R index 04f43382..72f70782 100644 --- a/tests/testthat/test_marginalisation.R +++ b/tests/testthat/test_marginalisation.R @@ -243,14 +243,12 @@ test_that("laplace approximation converges on correct posterior", { theta_mu <- (y * obs_prec + mu * prec) * theta_var theta_sd <- sqrt(theta_var) - # Laplace solution: - lik <- function(theta) { - distribution(y) <- normal(t(theta), obs_sd) - } - # mock up as a multivariate normal distribution mean <- ones(1, 8) * mu sigma <- diag(8) * sd ^ 2 + lik <- function(theta) { + distribution(y) <- normal(t(theta), obs_sd) + } out <- marginalise(lik, multivariate_normal(mean, sigma), laplace_approximation(diagonal_hessian = TRUE)) @@ -264,6 +262,5 @@ test_that("laplace approximation converges on correct posterior", { # compare these to within a tolerance compare_op(analytic, laplace) - # modes are right, sds are not! })