From 1e20f656369391f89e0b53881c21719b1c3a2eb1 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Mon, 20 May 2024 11:15:05 +0100 Subject: [PATCH] Tidying --- tests/base/test_block_system.py | 2 +- tests/firedrake/test_hessian_system.py | 3 +- tlm_adjoint/block_system.py | 28 +++++++------ tlm_adjoint/firedrake/assembly.py | 2 +- tlm_adjoint/firedrake/block_system.py | 55 +++++++++++-------------- tlm_adjoint/firedrake/hessian_system.py | 21 +++++++--- tlm_adjoint/firedrake/variables.py | 8 ++-- tlm_adjoint/overloaded_float.py | 2 +- 8 files changed, 62 insertions(+), 59 deletions(-) diff --git a/tests/base/test_block_system.py b/tests/base/test_block_system.py index 09598793c..214b0f948 100644 --- a/tests/base/test_block_system.py +++ b/tests/base/test_block_system.py @@ -27,7 +27,7 @@ def __init__(self, space, alpha): @no_float_overloading def mult_add(self, x, y): - y.addto(alpha * x.value) + y.addto(alpha * x) alpha = 0.5 b = Float(-2.0) diff --git a/tests/firedrake/test_hessian_system.py b/tests/firedrake/test_hessian_system.py index f054a729f..7e22f9df3 100644 --- a/tests/firedrake/test_hessian_system.py +++ b/tests/firedrake/test_hessian_system.py @@ -128,7 +128,8 @@ def forward_J(m): solver_parameters={"ksp_type": "cg", "ksp_atol": 1.0e-12, "ksp_rtol": 1.0e-12}, - pc_fn=pc_fn) + pc_fn=pc_fn, + nullspace=nullspace) H_solver.solve( v, b_ref) ksp_its = H_solver.ksp.getIterationNumber() diff --git a/tlm_adjoint/block_system.py b/tlm_adjoint/block_system.py index 1b3669b33..21335a66c 100644 --- a/tlm_adjoint/block_system.py +++ b/tlm_adjoint/block_system.py @@ -246,9 +246,7 @@ class MixedSpace(PETScVecInterface, Sequence): mixed_space.from_petsc(u_petsc, ((u_0, u_1), u_2)) - :arg spaces: Defines the split space. A :class:`Sequence` whose elements - are backend space or :class:`.TypedSpace` objects, or similar - :class:`Sequence` objects. + :arg spaces: Defines the split space. """ def __init__(self, spaces): @@ -618,8 +616,8 @@ class BlockMatrix(Matrix, MutableMapping): r"""A matrix defining a mapping :math:`A` mapping :math:`V \rightarrow W`, where :math:`V` and :math:`W` are defined by mixed spaces. - :arg arg_spaces: Defines the space `V`. - :arg action_spaces: Defines the space `W`. + :arg arg_space: Defines the space `V`. + :arg action_space: Defines the space `W`. :arg block: A :class:`Mapping` defining the blocks of the matrix. Items are `((i, j), block)` where the block in the `i` th and `j` th column is defined by `block`. Each `block` is a @@ -627,15 +625,15 @@ class BlockMatrix(Matrix, MutableMapping): block. """ - def __init__(self, arg_spaces, action_spaces, blocks=None): - if not isinstance(arg_spaces, MixedSpace): - arg_spaces = MixedSpace(arg_spaces) - if not isinstance(action_spaces, MixedSpace): - action_spaces = MixedSpace(action_spaces) + def __init__(self, arg_space, action_space, blocks=None): + if not isinstance(arg_space, MixedSpace): + arg_space = MixedSpace(arg_space) + if not isinstance(action_space, MixedSpace): + action_space = MixedSpace(action_space) if not isinstance(blocks, Mapping): blocks = {(0, 0): blocks} - super().__init__(arg_spaces, action_spaces) + super().__init__(arg_space, action_space) self._blocks = {} if blocks is not None: @@ -831,8 +829,14 @@ def pc_fn(u, b): def __init__(self, A, *, nullspace=None, solver_parameters=None, pc_fn=None, comm=None): + if nullspace is None: + nullspace = NoneNullspace() + elif isinstance(nullspace, Sequence): + nullspace = BlockNullspace(nullspace) if not isinstance(A, BlockMatrix): A = BlockMatrix((A.arg_space,), (A.action_space,), A) + if not isinstance(nullspace, NoneNullspace): + nullspace = BlockNullspace((nullspace,)) if solver_parameters is None: solver_parameters = {} if pc_fn is None: @@ -883,7 +887,6 @@ def solve(self, u, b, *, pc_fn = self._pc_fn if not isinstance(u, Sequence): u = (u,) - pc_fn_u = pc_fn def pc_fn(u, b): @@ -892,7 +895,6 @@ def pc_fn(u, b): if not isinstance(b, Sequence): b = (b,) - pc_fn_b = pc_fn def pc_fn(u, b): diff --git a/tlm_adjoint/firedrake/assembly.py b/tlm_adjoint/firedrake/assembly.py index 87d1e0609..5f08e19aa 100644 --- a/tlm_adjoint/firedrake/assembly.py +++ b/tlm_adjoint/firedrake/assembly.py @@ -133,7 +133,7 @@ def subtract_adjoint_derivative_actions(self, adj_x, nl_deps, dep_Bs): if isinstance(dF_comp, ufl.classes.Form): dF_comp = ufl.classes.Form( [integral.reconstruct(integrand=ufl.conj(integral.integrand())) # noqa: E501 - for integral in dF_comp.integrals()]) # noqa: E501 + for integral in dF_comp.integrals()]) else: if complex_mode: # See Firedrake issue #3346 diff --git a/tlm_adjoint/firedrake/block_system.py b/tlm_adjoint/firedrake/block_system.py index 75e8a15d6..863c72388 100644 --- a/tlm_adjoint/firedrake/block_system.py +++ b/tlm_adjoint/firedrake/block_system.py @@ -1,20 +1,18 @@ """Firedrake specific extensions to :mod:`tlm_adjoint.block_system`. """ +from .backend import TestFunction, backend_assemble, backend_DirichletBC +from ..interface import space_eq, var_axpy, var_inner, var_new + from ..block_system import ( BlockMatrix as _BlockMatrix, BlockNullspace, LinearSolver as _LinearSolver, Matrix, MixedSpace, NoneNullspace, Nullspace, TypedSpace) -from ..interface import space_eq - -from firedrake import Constant, DirichletBC, Function, TestFunction, assemble -import ufl +from .backend_interface import assemble, matrix_multiply +from .variables import Constant, Function from collections.abc import Sequence -try: - import mpi4py.MPI as MPI -except ModuleNotFoundError: - MPI = None +import ufl __all__ = \ [ @@ -23,14 +21,14 @@ "Nullspace", "NoneNullspace", + "BlockNullspace", "ConstantNullspace", "UnityNullspace", "DirichletBCNullspace", - "BlockNullspace", "Matrix", - "PETScMatrix", "BlockMatrix", + "PETScMatrix", "form_matrix", "LinearSolver" @@ -114,11 +112,8 @@ def __init__(self, space, *, alpha=1.0): @staticmethod def _correct(x, y, u, v, *, alpha=1.0): - with x.dat.vec_ro as x_v, u.dat.vec_ro as u_v: - u_x = x_v.dot(u_v) - - with y.dat.vec as y_v, v.dat.vec_ro as v_v: - y_v.axpy(alpha * u_x, v_v) + u_x = var_inner(x, u) + var_axpy(y, alpha * u_x, v) def apply_nullspace_transformation_lhs_right(self, x): if not space_eq(x.function_space(), self._space): @@ -176,9 +171,9 @@ def __init__(self, bcs, *, alpha=1.0): raise ValueError("Homogeneous boundary conditions required") super().__init__() + self._space = space self._bcs = bcs self._alpha = alpha - self._c = Function(space) def apply_nullspace_transformation_lhs_right(self, x): apply_bcs(x, self._bcs) @@ -187,15 +182,12 @@ def apply_nullspace_transformation_lhs_left(self, y): apply_bcs(y, self._bcs) def _constraint_correct_lhs(self, x, y, *, alpha=1.0): - with self._c.dat.vec_wo as c_v: - c_v.zeroEntries() - - apply_bcs(self._c, - tuple(DirichletBC(x.function_space(), x, bc.sub_domain) - for bc in self._bcs)) - - with self._c.dat.vec_ro as c_v, y.dat.vec as y_v: - y_v.axpy(alpha, c_v) + c = var_new(y) + apply_bcs( + c, + tuple(backend_DirichletBC(x.function_space(), x, bc.sub_domain) + for bc in self._bcs)) + var_axpy(y, alpha, c) def constraint_correct_lhs(self, x, y): self._constraint_correct_lhs(x, y, alpha=self._alpha) @@ -206,21 +198,20 @@ def pc_constraint_correct_soln(self, u, b): class PETScMatrix(Matrix): r"""A :class:`tlm_adjoint.block_system.Matrix` associated with a - :class:`firedrake.matrix.Matrix` :math:`A` mapping :math:`V \rightarrow W`. + :class:`firedrake.matrix.Matrix` :math:`A` defining a mapping + :math:`V \rightarrow W`. :arg arg_space: Defines the space `V`. :arg action_space: Defines the space `W`. :arg a: The :class:`firedrake.matrix.Matrix`. """ - def __init__(self, arg_space, action_space, a): + def __init__(self, arg_space, action_space, A): super().__init__(arg_space, action_space) - self._matrix = a + self._A = A def mult_add(self, x, y): - matrix = self._matrix.petscmat - with x.dat.vec_ro as x_v, y.dat.vec as y_v: - matrix.multAdd(x_v, y_v, y_v) + matrix_multiply(self._A, x, tensor=y, addto=True) def form_matrix(a, *args, **kwargs): @@ -239,7 +230,7 @@ def form_matrix(a, *args, **kwargs): return PETScMatrix( trial.function_space(), test.function_space().dual(), - assemble(a, *args, **kwargs)) + backend_assemble(a, *args, **kwargs)) class BlockMatrix(_BlockMatrix): diff --git a/tlm_adjoint/firedrake/hessian_system.py b/tlm_adjoint/firedrake/hessian_system.py index c31e11b44..fd57feb7c 100644 --- a/tlm_adjoint/firedrake/hessian_system.py +++ b/tlm_adjoint/firedrake/hessian_system.py @@ -3,12 +3,13 @@ var_axpy_conjugate, var_copy, var_copy_conjugate, var_dtype, var_inner, var_space_type) -from ..block_system import ( - BlockNullspace, LinearSolver, Matrix, NoneNullspace, Preconditioner, - TypedSpace, iter_sub, tuple_sub) +from ..block_system import Preconditioner, iter_sub, tuple_sub from ..eigendecomposition import eigendecompose from ..manager import manager_disabled +from .block_system import ( + BlockNullspace, LinearSolver, Matrix, NoneNullspace, TypedSpace) + from collections.abc import Sequence import numpy as np import petsc4py.PETSc as PETSc @@ -76,13 +77,21 @@ class HessianLinearSolver(LinearSolver): :class:`firedrake.cofunction.Cofunction`, or a :class:`Sequence` of :class:`firedrake.function.Function` or :class:`firedrake.cofunction.Cofunction` objects, defining the control. + :arg nullspace: A :class:`.Nullspace` or a :class:`Sequence` of + :class:`.Nullspace` objects defining the nullspace and left nullspace + of the Hessian matrix. `None` indicates a :class:`.NoneNullspace`. Remaining arguments are passed to the :class:`tlm_adjoint.block_system.LinearSolver` constructor. """ - def __init__(self, H, M, *args, **kwargs): - super().__init__(HessianMatrix(H, M), *args, **kwargs) + def __init__(self, H, M, *args, nullspace=None, **kwargs): + if nullspace is None: + nullspace = NoneNullspace() + elif not isinstance(nullspace, (NoneNullspace, BlockNullspace)): + nullspace = BlockNullspace(nullspace) + super().__init__(HessianMatrix(H, M), *args, nullspace=nullspace, + **kwargs) @manager_disabled() def solve(self, u, b, **kwargs): @@ -104,7 +113,7 @@ def solve(self, u, b, **kwargs): conjugate of the right-hand-side :math:`b`. Remaining arguments are handed to the - :class:`tlm_adjoint.block_system.LinearSolver`` method. + :meth:`tlm_adjoint.block_system.LinearSolver.solve` method. """ if is_var(b): diff --git a/tlm_adjoint/firedrake/variables.py b/tlm_adjoint/firedrake/variables.py index bd771c41b..7b751a7f8 100644 --- a/tlm_adjoint/firedrake/variables.py +++ b/tlm_adjoint/firedrake/variables.py @@ -5,10 +5,10 @@ FiniteElement, TensorElement, TestFunction, VectorElement, backend_Cofunction, backend_Constant, backend_Function, backend_ScalarType) from ..interface import ( - SpaceInterface, VariableInterface, add_replacement_interface, space_comm, - space_dtype, space_eq, register_subtract_adjoint_derivative_action, - space_id, subtract_adjoint_derivative_action_base, var_comm, var_dtype, - var_is_cached, var_is_static, var_linf_norm, var_lock_state, + SpaceInterface, VariableInterface, add_replacement_interface, + register_subtract_adjoint_derivative_action, space_comm, space_dtype, + space_eq, space_id, subtract_adjoint_derivative_action_base, var_comm, + var_dtype, var_is_cached, var_is_static, var_linf_norm, var_lock_state, var_scalar_value, var_space, var_space_type) from ..caches import Caches diff --git a/tlm_adjoint/overloaded_float.py b/tlm_adjoint/overloaded_float.py index b6d6eab66..337d8606a 100644 --- a/tlm_adjoint/overloaded_float.py +++ b/tlm_adjoint/overloaded_float.py @@ -38,7 +38,7 @@ from sympy.utilities.lambdify import lambdastr try: from sympy.printing.numpy import NumPyPrinter -except ModuleNotFoundError: +except ImportError: from sympy.printing.pycode import NumPyPrinter