Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Firedrake backend: Spatial white noise sampling #575

Merged
merged 3 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

extensions = ["autoapi.extension",
"nbsphinx",
"numpydoc",
"sphinx.ext.intersphinx",
"sphinx_rtd_theme"]

Expand All @@ -23,6 +24,8 @@

nbsphinx_execute = "auto"

numpydoc_validation_checks = {"all", "GL08"}

html_theme = "sphinx_rtd_theme"
html_theme_options = {"display_version": False}

Expand Down
41 changes: 40 additions & 1 deletion tests/firedrake/test_block_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from tlm_adjoint.firedrake.block_system import (
BlockMatrix, BlockNullspace, ConstantNullspace, DirichletBCNullspace,
Eigensolver, LinearSolver as _BlockLinearSolver, MatrixFreeMatrix,
MatrixFunctionSolver, UnityNullspace, form_matrix)
MatrixFunctionSolver, UnityNullspace, WhiteNoiseSampler, form_matrix)

from .test_base import *

import itertools
import mpi4py.MPI as MPI # noqa: N817
import numbers
import numpy as np
Expand Down Expand Up @@ -604,3 +605,41 @@ def test_M_root(setup_test, test_leaks):
error = var_copy(m_u_ref)
var_axpy(error, -1.0, M_u)
assert var_linf_norm(error) < 1.0e-15


@pytest.mark.firedrake
@pytest.mark.skipif(SLEPc is None, reason="SLEPc not available")
@pytest.mark.skipif(DEFAULT_COMM.size > 1, reason="serial only")
@pytest.mark.skipif(complex_mode, reason="real only")
@pytest.mark.parametrize("precondition", [False, True])
@seed_test
def test_white_noise_sampler(setup_test, test_leaks,
precondition):
mesh = UnitIntervalMesh(3)
space = FunctionSpace(mesh, "Lagrange", 3)
test = TestFunction(space)
trial = TrialFunction(space)

M = assemble(inner(trial, test) * dx)
V_ref = np.zeros((space_global_size(space), space_global_size(space)),
dtype=space_dtype(space))
for i, j in itertools.product(range(space_global_size(space)),
range(space_global_size(space))):
V_ref[i, j] = M.petscmat[i, j]

rng = np.random.default_rng(
np.random.SeedSequence(entropy=np.random.get_state()[1][0]))
sampler = WhiteNoiseSampler(
space, rng, precondition=precondition,
solver_parameters={"mfn_tol": 1.0e-10})

V = np.zeros((space_global_size(space), space_global_size(space)),
dtype=space_dtype(space))
N = 1000
for _ in range(N):
X = var_get_values(sampler.sample())
V += np.outer(X, X)
V /= N
error = abs(V - V_ref).max()
print(f"{error=}")
assert abs(error) < 0.015
7 changes: 3 additions & 4 deletions tlm_adjoint/block_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -1305,9 +1305,8 @@ def solve(self, u, v):
u_petsc.to_petsc(u)
v_petsc.to_petsc(v)

self.mfn.solve(u_petsc.vec, v_petsc.vec)
if self.mfn.getConvergedReason() <= 0:
raise RuntimeError("Convergence failure")
self.mfn.solve(u_petsc.vec, v_petsc.vec)
if self.mfn.getConvergedReason() <= 0:
raise RuntimeError("Convergence failure")

with paused_space_type_checking():
v_petsc.from_petsc(v)
127 changes: 124 additions & 3 deletions tlm_adjoint/firedrake/block_system.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
"""Firedrake specific extensions to :mod:`tlm_adjoint.block_system`.
"""

from .backend import TestFunction, backend_assemble, backend_DirichletBC
from ..interface import packed, space_eq, var_axpy, var_inner, var_new
from .backend import (
TestFunction, TrialFunction, backend_assemble, backend_DirichletBC, dx,
inner)
from ..interface import (
packed, space_dtype, space_eq, var_axpy, var_copy, var_dtype,
var_get_values, var_inner, var_local_size, var_new, var_set_values)

from ..block_system import (
BlockMatrix as _BlockMatrix, BlockNullspace, Eigensolver,
Expand All @@ -12,6 +16,7 @@
from .backend_interface import assemble, matrix_multiply
from .variables import Constant, Function

import numpy as np
import ufl

__all__ = \
Expand All @@ -34,7 +39,9 @@

"LinearSolver",
"Eigensolver",
"MatrixFunctionSolver"
"MatrixFunctionSolver",

"WhiteNoiseSampler"
]


Expand Down Expand Up @@ -238,3 +245,117 @@ def __init__(self, A, *args, **kwargs):
if isinstance(A, ufl.classes.Form):
A = form_matrix(A)
super().__init__(A, *args, **kwargs)


class WhiteNoiseSampler:
r"""White noise sampling.

Utility class for drawing independent spatial white noise samples.
Generates a sample using

.. math::

X = \Xi^{-T} \sqrt{ \Xi^T M \Xi } Z,

where

- :math:`M` is the mass matrix.
- :math:`\Xi` is a preconditioner.
- :math:`Z` is a vector whose elements are independent standard
Gaussian samples.

The matrix square root is computed using SLEPc.

Parameters
----------

space : :class:`firedrake.functionspaceimpl.WithGeometry`
The function space.
rng : :class:`numpy.random._generator.Generator`
Pseudorandom number generator.
precondition : :class:`bool`
If `True` then :math:`\Xi` is set equal to the inverse of the
(principal) square root of the diagonal of :math:`M`. Otherwise it is
set equal to the identity.
M : :class:`firedrake.matrix.Matrix`
Mass matrix. Constructed by finite element assembly if not supplied.
solver_parameters : :class:`Mapping`
Solver parameters.

Attributes
----------

space : :class:`firedrake.functionspaceimpl.WithGeometry`
The function space.
rng : :class:`numpy.random._generator.Generator`
Pseudorandom number generator.
"""

def __init__(self, space, rng, *, precondition=True, M=None,
solver_parameters=None):
if solver_parameters is None:
solver_parameters = {}
else:
solver_parameters = dict(solver_parameters)
if solver_parameters.get("mfn_type", "krylov") != "krylov":
raise ValueError("Invalid mfn_type")
if solver_parameters.get("fn_type", "sqrt") != "sqrt":
raise ValueError("Invalid fn_type")
solver_parameters.update({"mfn_type": "krylov",
"fn_type": "sqrt"})

if not issubclass(space_dtype(space), np.floating):
raise ValueError("Real space required")
if M is None:
test = TestFunction(space)
trial = TrialFunction(space)
M = assemble(inner(trial, test) * dx)

if precondition:
M_diag = M.petscmat.getDiagonal()
pc = np.sqrt(M_diag.getArray(True))
else:
pc = None

def mult(x, y):
if pc is not None:
x = var_copy(x)
var_set_values(x, var_get_values(x) / pc)
matrix_multiply(M, x, tensor=y)
if pc is not None:
var_set_values(y, var_get_values(y) / pc)

self._space = space
self._M = M
self._rng = rng
self._pc = pc
self._mfn = MatrixFunctionSolver(
MatrixFreeMatrix(space, space, mult),
solver_parameters=solver_parameters)

@property
def space(self):
return self._space

@property
def rng(self):
return self._rng

def sample(self):
"""Generate a new sample.

Returns
-------
X : :class:`firedrake.function.Function` or \
:class:`firedrake.cofunction.Cofunction`
The sample.
"""

Z = Function(self.space)
var_set_values(
Z, self.rng.standard_normal(var_local_size(Z), dtype=var_dtype(Z)))
X = Function(self.space)
self._mfn.solve(Z, X)
if self._pc is not None:
var_set_values(X, var_get_values(X) * self._pc)
return X