Skip to content

Commit

Permalink
Tidying
Browse files Browse the repository at this point in the history
  • Loading branch information
jrmaddison committed May 20, 2024
1 parent 1b7614b commit 1e20f65
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 59 deletions.
2 changes: 1 addition & 1 deletion tests/base/test_block_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, space, alpha):

@no_float_overloading
def mult_add(self, x, y):
y.addto(alpha * x.value)
y.addto(alpha * x)

alpha = 0.5
b = Float(-2.0)
Expand Down
3 changes: 2 additions & 1 deletion tests/firedrake/test_hessian_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ def forward_J(m):
solver_parameters={"ksp_type": "cg",
"ksp_atol": 1.0e-12,
"ksp_rtol": 1.0e-12},
pc_fn=pc_fn)
pc_fn=pc_fn,
nullspace=nullspace)
H_solver.solve(
v, b_ref)
ksp_its = H_solver.ksp.getIterationNumber()
Expand Down
28 changes: 15 additions & 13 deletions tlm_adjoint/block_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,7 @@ class MixedSpace(PETScVecInterface, Sequence):
mixed_space.from_petsc(u_petsc, ((u_0, u_1), u_2))
:arg spaces: Defines the split space. A :class:`Sequence` whose elements
are backend space or :class:`.TypedSpace` objects, or similar
:class:`Sequence` objects.
:arg spaces: Defines the split space.
"""

def __init__(self, spaces):
Expand Down Expand Up @@ -618,24 +616,24 @@ class BlockMatrix(Matrix, MutableMapping):
r"""A matrix defining a mapping :math:`A` mapping :math:`V \rightarrow W`,
where :math:`V` and :math:`W` are defined by mixed spaces.
:arg arg_spaces: Defines the space `V`.
:arg action_spaces: Defines the space `W`.
:arg arg_space: Defines the space `V`.
:arg action_space: Defines the space `W`.
:arg block: A :class:`Mapping` defining the blocks of the matrix. Items are
`((i, j), block)` where the block in the `i` th and `j` th column is
defined by `block`. Each `block` is a
:class:`tlm_adjoint.block_system.Matrix`, or `None` to indicate a zero
block.
"""

def __init__(self, arg_spaces, action_spaces, blocks=None):
if not isinstance(arg_spaces, MixedSpace):
arg_spaces = MixedSpace(arg_spaces)
if not isinstance(action_spaces, MixedSpace):
action_spaces = MixedSpace(action_spaces)
def __init__(self, arg_space, action_space, blocks=None):
if not isinstance(arg_space, MixedSpace):
arg_space = MixedSpace(arg_space)
if not isinstance(action_space, MixedSpace):
action_space = MixedSpace(action_space)
if not isinstance(blocks, Mapping):
blocks = {(0, 0): blocks}

super().__init__(arg_spaces, action_spaces)
super().__init__(arg_space, action_space)
self._blocks = {}

if blocks is not None:
Expand Down Expand Up @@ -831,8 +829,14 @@ def pc_fn(u, b):

def __init__(self, A, *, nullspace=None, solver_parameters=None,
pc_fn=None, comm=None):
if nullspace is None:
nullspace = NoneNullspace()
elif isinstance(nullspace, Sequence):
nullspace = BlockNullspace(nullspace)
if not isinstance(A, BlockMatrix):
A = BlockMatrix((A.arg_space,), (A.action_space,), A)
if not isinstance(nullspace, NoneNullspace):
nullspace = BlockNullspace((nullspace,))
if solver_parameters is None:
solver_parameters = {}
if pc_fn is None:
Expand Down Expand Up @@ -883,7 +887,6 @@ def solve(self, u, b, *,
pc_fn = self._pc_fn
if not isinstance(u, Sequence):
u = (u,)

pc_fn_u = pc_fn

def pc_fn(u, b):
Expand All @@ -892,7 +895,6 @@ def pc_fn(u, b):

if not isinstance(b, Sequence):
b = (b,)

pc_fn_b = pc_fn

def pc_fn(u, b):
Expand Down
2 changes: 1 addition & 1 deletion tlm_adjoint/firedrake/assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def subtract_adjoint_derivative_actions(self, adj_x, nl_deps, dep_Bs):
if isinstance(dF_comp, ufl.classes.Form):
dF_comp = ufl.classes.Form(
[integral.reconstruct(integrand=ufl.conj(integral.integrand())) # noqa: E501
for integral in dF_comp.integrals()]) # noqa: E501
for integral in dF_comp.integrals()])
else:
if complex_mode:
# See Firedrake issue #3346
Expand Down
55 changes: 23 additions & 32 deletions tlm_adjoint/firedrake/block_system.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
"""Firedrake specific extensions to :mod:`tlm_adjoint.block_system`.
"""

from .backend import TestFunction, backend_assemble, backend_DirichletBC
from ..interface import space_eq, var_axpy, var_inner, var_new

from ..block_system import (
BlockMatrix as _BlockMatrix, BlockNullspace, LinearSolver as _LinearSolver,
Matrix, MixedSpace, NoneNullspace, Nullspace, TypedSpace)
from ..interface import space_eq

from firedrake import Constant, DirichletBC, Function, TestFunction, assemble

import ufl
from .backend_interface import assemble, matrix_multiply
from .variables import Constant, Function

from collections.abc import Sequence
try:
import mpi4py.MPI as MPI
except ModuleNotFoundError:
MPI = None
import ufl

__all__ = \
[
Expand All @@ -23,14 +21,14 @@

"Nullspace",
"NoneNullspace",
"BlockNullspace",
"ConstantNullspace",
"UnityNullspace",
"DirichletBCNullspace",
"BlockNullspace",

"Matrix",
"PETScMatrix",
"BlockMatrix",
"PETScMatrix",
"form_matrix",

"LinearSolver"
Expand Down Expand Up @@ -114,11 +112,8 @@ def __init__(self, space, *, alpha=1.0):

@staticmethod
def _correct(x, y, u, v, *, alpha=1.0):
with x.dat.vec_ro as x_v, u.dat.vec_ro as u_v:
u_x = x_v.dot(u_v)

with y.dat.vec as y_v, v.dat.vec_ro as v_v:
y_v.axpy(alpha * u_x, v_v)
u_x = var_inner(x, u)
var_axpy(y, alpha * u_x, v)

def apply_nullspace_transformation_lhs_right(self, x):
if not space_eq(x.function_space(), self._space):
Expand Down Expand Up @@ -176,9 +171,9 @@ def __init__(self, bcs, *, alpha=1.0):
raise ValueError("Homogeneous boundary conditions required")

super().__init__()
self._space = space
self._bcs = bcs
self._alpha = alpha
self._c = Function(space)

def apply_nullspace_transformation_lhs_right(self, x):
apply_bcs(x, self._bcs)
Expand All @@ -187,15 +182,12 @@ def apply_nullspace_transformation_lhs_left(self, y):
apply_bcs(y, self._bcs)

def _constraint_correct_lhs(self, x, y, *, alpha=1.0):
with self._c.dat.vec_wo as c_v:
c_v.zeroEntries()

apply_bcs(self._c,
tuple(DirichletBC(x.function_space(), x, bc.sub_domain)
for bc in self._bcs))

with self._c.dat.vec_ro as c_v, y.dat.vec as y_v:
y_v.axpy(alpha, c_v)
c = var_new(y)
apply_bcs(
c,
tuple(backend_DirichletBC(x.function_space(), x, bc.sub_domain)
for bc in self._bcs))
var_axpy(y, alpha, c)

def constraint_correct_lhs(self, x, y):
self._constraint_correct_lhs(x, y, alpha=self._alpha)
Expand All @@ -206,21 +198,20 @@ def pc_constraint_correct_soln(self, u, b):

class PETScMatrix(Matrix):
r"""A :class:`tlm_adjoint.block_system.Matrix` associated with a
:class:`firedrake.matrix.Matrix` :math:`A` mapping :math:`V \rightarrow W`.
:class:`firedrake.matrix.Matrix` :math:`A` defining a mapping
:math:`V \rightarrow W`.
:arg arg_space: Defines the space `V`.
:arg action_space: Defines the space `W`.
:arg a: The :class:`firedrake.matrix.Matrix`.
"""

def __init__(self, arg_space, action_space, a):
def __init__(self, arg_space, action_space, A):
super().__init__(arg_space, action_space)
self._matrix = a
self._A = A

def mult_add(self, x, y):
matrix = self._matrix.petscmat
with x.dat.vec_ro as x_v, y.dat.vec as y_v:
matrix.multAdd(x_v, y_v, y_v)
matrix_multiply(self._A, x, tensor=y, addto=True)


def form_matrix(a, *args, **kwargs):
Expand All @@ -239,7 +230,7 @@ def form_matrix(a, *args, **kwargs):

return PETScMatrix(
trial.function_space(), test.function_space().dual(),
assemble(a, *args, **kwargs))
backend_assemble(a, *args, **kwargs))


class BlockMatrix(_BlockMatrix):
Expand Down
21 changes: 15 additions & 6 deletions tlm_adjoint/firedrake/hessian_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
var_axpy_conjugate, var_copy, var_copy_conjugate, var_dtype, var_inner,
var_space_type)

from ..block_system import (
BlockNullspace, LinearSolver, Matrix, NoneNullspace, Preconditioner,
TypedSpace, iter_sub, tuple_sub)
from ..block_system import Preconditioner, iter_sub, tuple_sub
from ..eigendecomposition import eigendecompose
from ..manager import manager_disabled

from .block_system import (
BlockNullspace, LinearSolver, Matrix, NoneNullspace, TypedSpace)

from collections.abc import Sequence
import numpy as np
import petsc4py.PETSc as PETSc
Expand Down Expand Up @@ -76,13 +77,21 @@ class HessianLinearSolver(LinearSolver):
:class:`firedrake.cofunction.Cofunction`, or a :class:`Sequence` of
:class:`firedrake.function.Function` or
:class:`firedrake.cofunction.Cofunction` objects, defining the control.
:arg nullspace: A :class:`.Nullspace` or a :class:`Sequence` of
:class:`.Nullspace` objects defining the nullspace and left nullspace
of the Hessian matrix. `None` indicates a :class:`.NoneNullspace`.
Remaining arguments are passed to the
:class:`tlm_adjoint.block_system.LinearSolver` constructor.
"""

def __init__(self, H, M, *args, **kwargs):
super().__init__(HessianMatrix(H, M), *args, **kwargs)
def __init__(self, H, M, *args, nullspace=None, **kwargs):
if nullspace is None:
nullspace = NoneNullspace()
elif not isinstance(nullspace, (NoneNullspace, BlockNullspace)):
nullspace = BlockNullspace(nullspace)
super().__init__(HessianMatrix(H, M), *args, nullspace=nullspace,
**kwargs)

@manager_disabled()
def solve(self, u, b, **kwargs):
Expand All @@ -104,7 +113,7 @@ def solve(self, u, b, **kwargs):
conjugate of the right-hand-side :math:`b`.
Remaining arguments are handed to the
:class:`tlm_adjoint.block_system.LinearSolver`` method.
:meth:`tlm_adjoint.block_system.LinearSolver.solve` method.
"""

if is_var(b):
Expand Down
8 changes: 4 additions & 4 deletions tlm_adjoint/firedrake/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
FiniteElement, TensorElement, TestFunction, VectorElement,
backend_Cofunction, backend_Constant, backend_Function, backend_ScalarType)
from ..interface import (
SpaceInterface, VariableInterface, add_replacement_interface, space_comm,
space_dtype, space_eq, register_subtract_adjoint_derivative_action,
space_id, subtract_adjoint_derivative_action_base, var_comm, var_dtype,
var_is_cached, var_is_static, var_linf_norm, var_lock_state,
SpaceInterface, VariableInterface, add_replacement_interface,
register_subtract_adjoint_derivative_action, space_comm, space_dtype,
space_eq, space_id, subtract_adjoint_derivative_action_base, var_comm,
var_dtype, var_is_cached, var_is_static, var_linf_norm, var_lock_state,
var_scalar_value, var_space, var_space_type)

from ..caches import Caches
Expand Down
2 changes: 1 addition & 1 deletion tlm_adjoint/overloaded_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from sympy.utilities.lambdify import lambdastr
try:
from sympy.printing.numpy import NumPyPrinter
except ModuleNotFoundError:
except ImportError:
from sympy.printing.pycode import NumPyPrinter


Expand Down

0 comments on commit 1e20f65

Please sign in to comment.