Skip to content

Commit

Permalink
Merge pull request #577 from tlm-adjoint/jrmaddison/white_noise
Browse files Browse the repository at this point in the history
Firedrake backend: Apply the Riesz map when drawing white noise samples
  • Loading branch information
jrmaddison authored Jun 4, 2024
2 parents 46533c6 + f6bd081 commit 6c27141
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 35 deletions.
13 changes: 8 additions & 5 deletions tests/firedrake/test_block_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,17 +621,18 @@ def test_white_noise_sampler(setup_test, test_leaks,
trial = TrialFunction(space)

M = assemble(inner(trial, test) * dx)
V_ref = np.zeros((space_global_size(space), space_global_size(space)),
P_ref = np.zeros((space_global_size(space), space_global_size(space)),
dtype=space_dtype(space))
for i, j in itertools.product(range(space_global_size(space)),
range(space_global_size(space))):
V_ref[i, j] = M.petscmat[i, j]
P_ref[i, j] = M.petscmat[i, j]

rng = np.random.default_rng(
np.random.SeedSequence(entropy=np.random.get_state()[1][0]))
sampler = WhiteNoiseSampler(
space, rng, precondition=precondition,
solver_parameters={"mfn_tol": 1.0e-10})
mfn_solver_parameters={"mfn_tol": 1.0e-10},
ksp_solver_parameters=ls_parameters_cg)

V = np.zeros((space_global_size(space), space_global_size(space)),
dtype=space_dtype(space))
Expand All @@ -640,6 +641,8 @@ def test_white_noise_sampler(setup_test, test_leaks,
X = var_get_values(sampler.sample())
V += np.outer(X, X)
V /= N
error = abs(V - V_ref).max()
error = abs(V @ P_ref - np.eye(V.shape[0], dtype=V.dtype)).max()
print(f"{error=}")
assert abs(error) < 0.015
assert abs(error) < 0.17

del sampler._mfn
102 changes: 72 additions & 30 deletions tlm_adjoint/firedrake/block_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
"""

from .backend import (
TestFunction, TrialFunction, backend_assemble, backend_DirichletBC, dx,
inner)
LinearSolver as backend_LinearSolver, TestFunction, TrialFunction,
backend_assemble, backend_DirichletBC, dx, inner)
from ..interface import (
packed, space_dtype, space_eq, var_axpy, var_copy, var_dtype,
var_get_values, var_inner, var_local_size, var_new, var_set_values)
Expand All @@ -14,8 +14,10 @@
MatrixFreeMatrix, MixedSpace, NoneNullspace, Nullspace, TypedSpace)

from .backend_interface import assemble, matrix_multiply
from .variables import Constant, Function
from .parameters import copy_parameters
from .variables import Cofunction, Constant, Function

from functools import cached_property
import numpy as np
import ufl

Expand Down Expand Up @@ -279,7 +281,7 @@ class WhiteNoiseSampler:
.. math::
X = \Xi^{-T} \sqrt{ \Xi^T M \Xi } Z,
X = M^{-1} \Xi^{-T} \sqrt{ \Xi^T M \Xi } Z,
where
Expand All @@ -303,8 +305,11 @@ class WhiteNoiseSampler:
set equal to the identity.
M : :class:`firedrake.matrix.Matrix`
Mass matrix. Constructed by finite element assembly if not supplied.
solver_parameters : :class:`Mapping`
Solver parameters.
mfn_solver_parameters : :class:`Mapping`
:class:`slepc4py.SLEPc.MFN` solver parameters, used for the matrix
square root action.
ksp_solver_parameters : :class:`Mapping`
Solver parameters, used for :math:`M^{-1}`.
Attributes
----------
Expand All @@ -316,17 +321,19 @@ class WhiteNoiseSampler:
"""

def __init__(self, space, rng, *, precondition=True, M=None,
solver_parameters=None):
if solver_parameters is None:
solver_parameters = {}
mfn_solver_parameters=None, ksp_solver_parameters=None):
if mfn_solver_parameters is None:
mfn_solver_parameters = {}
else:
solver_parameters = dict(solver_parameters)
if solver_parameters.get("mfn_type", "krylov") != "krylov":
mfn_solver_parameters = dict(mfn_solver_parameters)
if mfn_solver_parameters.get("mfn_type", "krylov") != "krylov":
raise ValueError("Invalid mfn_type")
if solver_parameters.get("fn_type", "sqrt") != "sqrt":
if mfn_solver_parameters.get("fn_type", "sqrt") != "sqrt":
raise ValueError("Invalid fn_type")
solver_parameters.update({"mfn_type": "krylov",
"fn_type": "sqrt"})
mfn_solver_parameters.update({"mfn_type": "krylov",
"fn_type": "sqrt"})
if ksp_solver_parameters is None:
ksp_solver_parameters = {}

if not issubclass(space_dtype(space), np.floating):
raise ValueError("Real space required")
Expand All @@ -341,21 +348,31 @@ def __init__(self, space, rng, *, precondition=True, M=None,
else:
pc = None

def mult(x, y):
if pc is not None:
x = var_copy(x)
var_set_values(x, var_get_values(x) / pc)
matrix_multiply(M, x, tensor=y)
if pc is not None:
var_set_values(y, var_get_values(y) / pc)

self._space = space
self._M = M
self._rng = rng
self._pc = pc
self._mfn = MatrixFunctionSolver(
MatrixFreeMatrix(space, space, mult),
solver_parameters=solver_parameters)
self._mfn_solver_parameters = copy_parameters(mfn_solver_parameters)
self._ksp_solver_parameters = copy_parameters(ksp_solver_parameters)

@cached_property
def _mfn(self):
def mult(x, y):
if self._pc is not None:
x = var_copy(x)
var_set_values(x, var_get_values(x) / self._pc)
matrix_multiply(self._M, x, tensor=y)
if self._pc is not None:
var_set_values(y, var_get_values(y) / self._pc)

return MatrixFunctionSolver(
MatrixFreeMatrix(self.space, self.space, mult),
solver_parameters=self._mfn_solver_parameters)

@cached_property
def _ksp(self):
return backend_LinearSolver(
self._M, solver_parameters=self._ksp_solver_parameters)

@property
def space(self):
Expand All @@ -365,21 +382,46 @@ def space(self):
def rng(self):
return self._rng

def sample(self):
"""Generate a new sample.
def dual_sample(self):
r"""Generate a new sample in the dual space.
The result is given by
.. math::
X = \Xi^{-T} \sqrt{ \Xi^T M \Xi } Z.
Returns
-------
:class:`firedrake.function.Function` or \
:class:`firedrake.cofunction.Cofunction`
:class:`firedrake.cofunction.Cofunction`
The sample.
"""

Z = Function(self.space)
var_set_values(
Z, self.rng.standard_normal(var_local_size(Z), dtype=var_dtype(Z)))
X = Function(self.space)
X = Cofunction(self.space.dual())
self._mfn.solve(Z, X)
if self._pc is not None:
var_set_values(X, var_get_values(X) * self._pc)
return X

def sample(self):
r"""Generate a new sample.
The result is given by
.. math::
X = M^{-1} \Xi^{-T} \sqrt{ \Xi^T M \Xi } Z.
Returns
-------
:class:`firedrake.function.Function`
The sample.
"""

Y = self.dual_sample()
X = Function(self.space)
self._ksp.solve(X, Y)
return X

0 comments on commit 6c27141

Please sign in to comment.