From d83eaf01eca1779bd44a20d752a5e8743b9fcd76 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Tue, 28 May 2024 09:39:07 +0100 Subject: [PATCH 1/4] Add Packed class, based on pyadjoint Enlist functionality --- tlm_adjoint/block_system.py | 42 ++++---- tlm_adjoint/cached_hessian.py | 31 +++--- tlm_adjoint/equation.py | 77 ++++++--------- tlm_adjoint/equations.py | 35 ++++--- tlm_adjoint/fenics/backend_interface.py | 28 +++--- tlm_adjoint/fenics/backend_patches.py | 14 +-- tlm_adjoint/fenics/caches.py | 22 ++--- tlm_adjoint/fenics/interpolation.py | 9 +- tlm_adjoint/fenics/solve.py | 8 +- tlm_adjoint/firedrake/backend_interface.py | 31 +++--- tlm_adjoint/firedrake/block_system.py | 12 +-- tlm_adjoint/firedrake/caches.py | 23 ++--- tlm_adjoint/firedrake/interpolation.py | 10 +- tlm_adjoint/firedrake/solve.py | 12 +-- tlm_adjoint/fixed_point.py | 22 +---- tlm_adjoint/hessian.py | 43 ++++---- tlm_adjoint/hessian_system.py | 16 +-- tlm_adjoint/interface.py | 71 +++++++++++-- tlm_adjoint/jax.py | 18 ++-- tlm_adjoint/linear_equation.py | 47 ++++----- tlm_adjoint/markers.py | 9 +- tlm_adjoint/optimization.py | 110 ++++++++------------- tlm_adjoint/tangent_linear.py | 13 +-- tlm_adjoint/tlm_adjoint.py | 52 ++-------- tlm_adjoint/verification.py | 35 +++---- 25 files changed, 346 insertions(+), 444 deletions(-) diff --git a/tlm_adjoint/block_system.py b/tlm_adjoint/block_system.py index 98a93500..3bd31987 100644 --- a/tlm_adjoint/block_system.py +++ b/tlm_adjoint/block_system.py @@ -70,8 +70,8 @@ """ from .interface import ( - comm_dup_cached, space_comm, space_default_space_type, space_eq, space_new, - var_assign, var_locked, var_zero) + comm_dup_cached, packed, space_comm, space_default_space_type, space_eq, + space_new, var_assign, var_locked, var_zero) from .manager import manager_disabled from .petsc import ( PETScOptions, PETScVec, PETScVecInterface, attach_destroy_finalizer) @@ -113,20 +113,23 @@ def expand(e): q = deque(map(expand, iterable)) while len(q) > 0: - e = q.popleft() - if isinstance(e, Sequence) and not isinstance(e, str): - q.extendleft(map(expand, reversed(e))) - else: + e = packed(q.popleft()) + if e.is_packed: + e, = e yield e + else: + q.extendleft(map(expand, reversed(e))) def tuple_sub(iterable, sequence): iterator = iter_sub(iterable) def tuple_sub(iterator, value): - if isinstance(value, Sequence) and not isinstance(value, str): + value = packed(value) + if value.is_packed: + return next(iterator) + else: return tuple(tuple_sub(iterator, e) for e in value) - return next(iterator) t = tuple_sub(iterator, sequence) @@ -252,17 +255,15 @@ class MixedSpace(PETScVecInterface, Sequence): def __init__(self, spaces): if isinstance(spaces, MixedSpace): spaces = spaces.split_space - elif isinstance(spaces, Sequence): - spaces = tuple(spaces) else: - spaces = (spaces,) + spaces = packed(spaces) flattened_spaces = tuple(space if isinstance(space, TypedSpace) else TypedSpace(space) for space in iter_sub(spaces)) spaces = tuple_sub(flattened_spaces, spaces) super().__init__(tuple(space.space for space in flattened_spaces)) - self._spaces = spaces + self._spaces = tuple_sub(spaces, spaces) self._flattened_spaces = flattened_spaces def __len__(self): @@ -489,10 +490,7 @@ class BlockNullspace(Nullspace, Sequence): """ def __init__(self, nullspaces): - if not isinstance(nullspaces, Sequence): - nullspaces = (nullspaces,) - - nullspaces = list(nullspaces) + nullspaces = list(packed(nullspaces)) for i, nullspace in enumerate(nullspaces): if nullspace is None: nullspaces[i] = NoneNullspace() @@ -502,8 +500,7 @@ def __init__(self, nullspaces): self._nullspaces = nullspaces def __new__(cls, nullspaces, *args, **kwargs): - if not isinstance(nullspaces, Sequence): - nullspaces = (nullspaces,) + nullspaces = packed(nullspaces) for nullspace in nullspaces: if nullspace is not None \ and not isinstance(nullspace, NoneNullspace): @@ -862,17 +859,18 @@ def solve(self, u, b, *, the solution. """ + u = packed(u) + b = packed(b) + pc_fn = self._pc_fn - if not isinstance(u, Sequence): - u = (u,) + if u.is_packed: pc_fn_u = pc_fn def pc_fn(u, b): u, = tuple(iter_sub(u)) pc_fn_u(u, b) - if not isinstance(b, Sequence): - b = (b,) + if b.is_packed: pc_fn_b = pc_fn def pc_fn(u, b): diff --git a/tlm_adjoint/cached_hessian.py b/tlm_adjoint/cached_hessian.py index 39c91084..907531fa 100644 --- a/tlm_adjoint/cached_hessian.py +++ b/tlm_adjoint/cached_hessian.py @@ -1,5 +1,5 @@ from .interface import ( - VariableStateLockDictionary, is_var, var_check_state_lock, var_id, + Packed, VariableStateLockDictionary, packed, var_check_state_lock, var_id, var_increment_state_lock, var_new, var_scalar_value) from .caches import clear_caches @@ -152,11 +152,10 @@ def __init__(self, J, *, manager=None, cache_adjoint=True): @restore_manager def compute_gradient(self, M, M0=None): - if is_var(M): - J_val, (dJ,) = self.compute_gradient( - (M,), - M0=None if M0 is None else (M0,)) - return J_val, dJ + M_packed = Packed(M) + M = tuple(M_packed) + if M0 is not None: + M0 = packed(M0) var_check_state_lock(self._J) @@ -173,15 +172,15 @@ def compute_gradient(self, M, M0=None): store_adjoint=self._cache_adjoint) reset_manager() - return J_val, dJ + return J_val, M_packed.unpack(dJ) @restore_manager def action(self, M, dM, M0=None): - if is_var(M): - J_val, dJ_val, (ddJ,) = self.action( - (M,), (dM,), - M0=None if M0 is None else (M0,)) - return J_val, dJ_val, ddJ + M_packed = Packed(M) + M = tuple(M_packed) + dM = packed(dM) + if M0 is not None: + M0 = packed(M0) var_check_state_lock(self._J) @@ -198,7 +197,7 @@ def action(self, M, dM, M0=None): store_adjoint=self._cache_adjoint) reset_manager() - return J_val, dJ_val, ddJ + return J_val, dJ_val, M_packed.unpack(ddJ) class CachedGaussNewton(GaussNewton, HessianOptimization): @@ -215,9 +214,7 @@ class CachedGaussNewton(GaussNewton, HessianOptimization): def __init__(self, X, R_inv_action, B_inv_action=None, *, manager=None): - if is_var(X): - X = (X,) - + X = packed(X) for x in X: var_increment_state_lock(x, self) @@ -225,7 +222,7 @@ def __init__(self, X, R_inv_action, B_inv_action=None, *, cache_adjoint=False) GaussNewton.__init__( self, R_inv_action, B_inv_action=B_inv_action) - self._X = tuple(X) + self._X = X def _setup_manager(self, M, dM, M0=None, *, annotate_tlm=False, solve_tlm=True): diff --git a/tlm_adjoint/equation.py b/tlm_adjoint/equation.py index d06d0f7b..3bbff6a4 100644 --- a/tlm_adjoint/equation.py +++ b/tlm_adjoint/equation.py @@ -1,12 +1,12 @@ from .interface import ( - check_space_types, is_var, var_id, var_is_alias, var_is_static, var_locked, - var_new, var_replacement, var_update_caches, var_update_state, var_zero) + Packed, check_space_types, is_var, packed, var_id, var_is_alias, + var_is_static, var_locked, var_new, var_replacement, var_update_caches, + var_update_state, var_zero) from .alias import WeakAlias, gc_disabled from .manager import manager as _manager from .manager import annotation_enabled, paused_manager, tlm_enabled -from collections.abc import Sequence import functools import inspect import itertools @@ -128,8 +128,8 @@ def __init__(self, X, deps, nl_deps=None, *, ic_deps=None, ic=None, adj_ic_deps=None, adj_ic=None, adj_type="conjugate_dual"): - if is_var(X): - X = (X,) + X_packed = Packed(X) + X = tuple(X_packed) X_ids = set(map(var_id, X)) dep_ids = {var_id(dep): i for i, dep in enumerate(deps)} for x in X: @@ -193,16 +193,14 @@ def __init__(self, X, deps, nl_deps=None, *, if adj_type in ["primal", "conjugate_dual"]: adj_type = tuple(adj_type for _ in X) - elif isinstance(adj_type, Sequence): - if len(adj_type) != len(X): - raise ValueError("Invalid adjoint type") - else: + if len(adj_type) != len(X): raise ValueError("Invalid adjoint type") for adj_x_type in adj_type: if adj_x_type not in {"primal", "conjugate_dual"}: raise ValueError("Invalid adjoint type") super().__init__() + self._packed = X_packed.mapped(lambda x: None) self._X = tuple(X) self._deps = tuple(deps) self._nl_deps = tuple(nl_deps) @@ -254,6 +252,9 @@ def x(self): x, = self._X return x + def _unpack(self, obj): + return self._packed.unpack(obj) + def X(self, m=None): """Return forward solution variables. @@ -397,7 +398,7 @@ def forward(self, X, deps=None): with var_locked(*(dep for dep in (eq_deps if deps is None else deps) if var_id(dep) not in X_ids)): var_update_caches(*eq_deps, value=deps) - self.forward_solve(X[0] if len(X) == 1 else X, deps=deps) + self.forward_solve(self._unpack(X), deps=deps) var_update_state(*X) var_update_caches(*self.X(), value=X) @@ -407,11 +408,9 @@ def forward_solve(self, X, deps=None): Can assume that the currently active :class:`.EquationManager` is paused. - :arg X: A variable if the forward solution has a single component, - otherwise a :class:`Sequence` of variables. May define an initial - guess, and should be set by this method. Subclasses may replace - this argument with `x` if the forward solution has a single - component. + :arg X: A variable or a :class:`Sequence` of variables storing the + solution. May define an initial guess, and should be set by this + method. :arg deps: A :class:`tuple` of variables, defining values for dependencies. Only the elements corresponding to `X` may be modified. `self.dependencies()` should be used if not supplied. @@ -428,8 +427,8 @@ def adjoint(self, adj_X, nl_deps, B, dep_Bs): returned. :arg nl_deps: A :class:`Sequence` of variables defining values for non-linear dependencies. Should not be modified. - :arg B: A sequence of variables defining the right-hand-side of the - adjoint equation. May be modified or returned. + :arg B: A :class:`Sequence` of variables defining the right-hand-side + of the adjoint equation. May be modified or returned. :arg dep_Bs: A :class:`Mapping` whose items are `(dep_index, dep_B)`. Each `dep_B` is an :class:`.AdjointRHS` which should be updated by subtracting adjoint derivative information computed by @@ -443,12 +442,11 @@ def adjoint(self, adj_X, nl_deps, B, dep_Bs): var_update_caches(*self.nonlinear_dependencies(), value=nl_deps) adj_X = self.adjoint_jacobian_solve( - adj_X if adj_X is None or len(adj_X) != 1 else adj_X[0], - nl_deps, B[0] if len(B) == 1 else B) + None if adj_X is None else self._unpack(adj_X), + nl_deps, self._unpack(B)) if adj_X is None: return None - elif is_var(adj_X): - adj_X = (adj_X,) + adj_X = packed(adj_X) var_update_state(*adj_X) for m, adj_x in enumerate(adj_X): @@ -456,7 +454,7 @@ def adjoint(self, adj_X, nl_deps, B, dep_Bs): rel_space_type=self.adj_X_type(m)) self.subtract_adjoint_derivative_actions( - adj_X[0] if len(adj_X) == 1 else adj_X, nl_deps, dep_Bs) + self._unpack(adj_X), nl_deps, dep_Bs) return tuple(adj_X) @@ -477,7 +475,7 @@ def adjoint_cached(self, adj_X, nl_deps, dep_Bs): var_update_caches(*self.nonlinear_dependencies(), value=nl_deps) self.subtract_adjoint_derivative_actions( - adj_X[0] if len(adj_X) == 1 else adj_X, nl_deps, dep_Bs) + self._unpack(adj_X), nl_deps, dep_Bs) def adjoint_derivative_action(self, nl_deps, dep_index, adj_X): """Return the action of the adjoint of a derivative of the forward @@ -489,10 +487,8 @@ def adjoint_derivative_action(self, nl_deps, dep_index, adj_X): :arg dep_index: An :class:`int`. The derivative is defined by differentiation of the forward residual with respect to `self.dependencies()[dep_index]`. - :arg adj_X: The adjoint solution. A variable if the adjoint solution - has a single component, otherwise a :class:`Sequence` of variables. - Should not be modified. Subclasses may replace this argument with - `adj_x` if the adjoint solution has a single component. + :arg adj_X: The adjoint solution. A variable or a :class:`Sequence` of + variables. Should not be modified. :returns: The action of the adjoint of a derivative on the adjoint solution. Will be passed to :func:`.subtract_adjoint_derivative_action`, and valid types depend @@ -510,10 +506,8 @@ def subtract_adjoint_derivative_actions(self, adj_X, nl_deps, dep_Bs): Can be overridden for an optimized implementation, but otherwise uses :meth:`.Equation.adjoint_derivative_action`. - :arg adj_X: The adjoint solution. A variable if the adjoint solution - has a single component, otherwise a :class:`Sequence` of variables. - Should not be modified. Subclasses may replace this argument with - `adj_x` if the adjoint solution has a single component. + :arg adj_X: The adjoint solution. A variable or a :class:`Sequence` of + variables. Should not be modified. :arg nl_deps: A :class:`Sequence` of variables defining values for non-linear dependencies. Should not be modified. :arg dep_Bs: A :class:`Mapping` whose items are `(dep_index, dep_B)`. @@ -529,18 +523,14 @@ def subtract_adjoint_derivative_actions(self, adj_X, nl_deps, dep_Bs): def adjoint_jacobian_solve(self, adj_X, nl_deps, B): """Compute an adjoint solution. - :arg adj_X: Either `None`, or a variable (if the adjoint solution has a - single component) or :class:`Sequence` of variables (otherwise) - defining the initial guess for an iterative solve. May be modified - or returned. Subclasses may replace this argument with `adj_x` if - the adjoint solution has a single component. + :arg adj_X: Either `None`, or a variable or :class:`Sequence` of + variables defining the initial guess for an iterative solve. May be + modified or returned. :arg nl_deps: A :class:`Sequence` of variables defining values for non-linear dependencies. Should not be modified. - :arg B: The right-hand-side. A variable (if the adjoint solution has a - single component) or :class:`Sequence` of variables (otherwise) - storing the value of the right-hand-side. May be modified or - returned. Subclasses may replace this argument with `b` if the - adjoint solution has a single component. + :arg B: The right-hand-side. A variable or :class:`Sequence` of + variables storing the value of the right-hand-side. May be modified + or returned. :returns: A variable or :class:`Sequence` of variables storing the value of the adjoint solution. May return `None` to indicate a value of zero. @@ -579,13 +569,10 @@ class ZeroAssignment(Equation): """ def __init__(self, X): - if is_var(X): - X = (X,) + X = packed(X) super().__init__(X, X, nl_deps=[], ic=False, adj_ic=False) def forward_solve(self, X, deps=None): - if is_var(X): - X = (X,) for x in X: var_zero(x) diff --git a/tlm_adjoint/equations.py b/tlm_adjoint/equations.py index bf663934..847dc3d3 100644 --- a/tlm_adjoint/equations.py +++ b/tlm_adjoint/equations.py @@ -1,6 +1,6 @@ from .interface import ( - check_space_types, check_space_types_conjugate_dual, - check_space_types_dual, is_var, var_assign, var_axpy, var_axpy_conjugate, + Packed, check_space_types, check_space_types_conjugate_dual, + check_space_types_dual, packed, var_assign, var_axpy, var_axpy_conjugate, var_dot, var_dtype, var_get_values, var_id, var_is_scalar, var_inner, var_local_size, var_new_conjugate_dual, var_replacement, var_scalar_value, var_set_values, var_zero) @@ -284,8 +284,8 @@ class MatrixActionRHS(RHS): """ def __init__(self, A, X): - if is_var(X): - X = (X,) + X_packed = Packed(X) + X = tuple(X_packed) if len(set(map(var_id, X))) != len(X): raise ValueError("Invalid dependency") @@ -305,6 +305,7 @@ def __init__(self, A, X): x_indices[nl_dep_ids[x_id]] = i super().__init__(nl_deps, nl_deps=nl_deps) + self._packed = X_packed.mapped(lambda x: None) self._A = A self._x_indices = x_indices @@ -314,31 +315,31 @@ def drop_references(self): super().drop_references() self._A = self._A._weak_alias + def _unpack(self, obj): + return self._packed.unpack(obj) + def add_forward(self, B, deps): - if is_var(B): - B = (B,) + B = packed(B) X = tuple(deps[j] for j in self._x_indices) self._A.forward_action(deps[:len(self._A.nonlinear_dependencies())], - X[0] if len(X) == 1 else X, - B[0] if len(B) == 1 else B, + self._unpack(X), + self._unpack(B), method="add") def subtract_adjoint_derivative_action(self, nl_deps, dep_index, adj_X, b): - if is_var(adj_X): - adj_X = (adj_X,) - + adj_X = packed(adj_X) N_A_nl_deps = len(self._A.nonlinear_dependencies()) if dep_index < N_A_nl_deps: X = tuple(nl_deps[j] for j in self._x_indices) self._A.adjoint_derivative_action( nl_deps[:N_A_nl_deps], dep_index, - X[0] if len(X) == 1 else X, - adj_X[0] if len(adj_X) == 1 else adj_X, + self._unpack(X), + self._unpack(adj_X), b, method="sub") if dep_index in self._x_indices: self._A.adjoint_action(nl_deps[:N_A_nl_deps], - adj_X[0] if len(adj_X) == 1 else adj_X, + self._unpack(adj_X), b, b_index=self._x_indices[dep_index], method="sub") @@ -348,16 +349,14 @@ def tangent_linear_rhs(self, tlm_map): X = tuple(deps[j] for j in self._x_indices) tlm_X = tuple(tlm_map[x] for x in X) - tlm_B = [MatrixActionRHS(self._A, tlm_X)] + tlm_B = [MatrixActionRHS(self._A, self._unpack(tlm_X))] if N_A_nl_deps > 0: tlm_b = self._A.tangent_linear_rhs(tlm_map, X) if tlm_b is None: pass - elif isinstance(tlm_b, RHS): - tlm_B.append(tlm_b) else: - tlm_B.extend(tlm_b) + tlm_B.extend(packed(tlm_b)) return tlm_B diff --git a/tlm_adjoint/fenics/backend_interface.py b/tlm_adjoint/fenics/backend_interface.py index 186c00af..2b7e2edb 100644 --- a/tlm_adjoint/fenics/backend_interface.py +++ b/tlm_adjoint/fenics/backend_interface.py @@ -1,11 +1,10 @@ from .backend import ( KrylovSolver, LUSolver, TestFunction, as_backend_type, backend_Constant, - backend_DirichletBC, backend_Function, backend_LocalSolver, - backend_ScalarType, backend_assemble, backend_assemble_system, - has_lu_solver_method) + backend_Function, backend_LocalSolver, backend_ScalarType, + backend_assemble, backend_assemble_system, has_lu_solver_method) from ..interface import ( - DEFAULT_COMM, check_space_type, check_space_types_conjugate_dual, space_eq, - space_new) + DEFAULT_COMM, check_space_type, check_space_types_conjugate_dual, packed, + space_eq, space_new) from .expr import eliminate_zeros from .parameters import update_parameters @@ -26,8 +25,8 @@ def assemble_matrix(form, bcs=None, *, form_compiler_parameters=None): if bcs is None: bcs = () - elif isinstance(bcs, backend_DirichletBC): - bcs = (bcs,) + else: + bcs = packed(bcs) if form_compiler_parameters is None: form_compiler_parameters = {} @@ -58,8 +57,8 @@ def assemble(form, tensor=None, bcs=None, *, check_space_type(tensor._tlm_adjoint__function, "conjugate_dual") if bcs is None: bcs = () - elif isinstance(bcs, backend_DirichletBC): - bcs = (bcs,) + else: + bcs = packed(bcs) form = eliminate_zeros(form) b = backend_assemble(form, tensor=tensor, @@ -71,10 +70,15 @@ def assemble(form, tensor=None, bcs=None, *, def assemble_system(A_form, b_form, bcs=None, *, form_compiler_parameters=None): + if bcs is None: + bcs = () + else: + bcs = packed(bcs) + A_form = eliminate_zeros(A_form) b_form = eliminate_zeros(b_form) return backend_assemble_system( - A_form, b_form, bcs=bcs, + A_form, b_form, bcs=tuple(bcs), form_compiler_parameters=form_compiler_parameters) @@ -115,8 +119,8 @@ def assemble_linear_solver(A_form, b_form=None, bcs=None, *, linear_solver_parameters=None): if bcs is None: bcs = () - elif isinstance(bcs, backend_DirichletBC): - bcs = (bcs,) + else: + bcs = packed(bcs) if form_compiler_parameters is None: form_compiler_parameters = {} if linear_solver_parameters is None: diff --git a/tlm_adjoint/fenics/backend_patches.py b/tlm_adjoint/fenics/backend_patches.py index a98de33d..6dbf694d 100644 --- a/tlm_adjoint/fenics/backend_patches.py +++ b/tlm_adjoint/fenics/backend_patches.py @@ -6,7 +6,7 @@ backend_ScalarType, backend_Vector, backend_assemble, backend_project, backend_solve, cpp_Assembler, cpp_PETScVector, cpp_SystemAssembler) from ..interface import ( - DEFAULT_COMM, add_interface, comm_dup_cached, comm_parent, is_var, + DEFAULT_COMM, add_interface, comm_dup_cached, comm_parent, is_var, packed, new_space_id, new_var_id, space_eq, space_id, space_new, var_assign, var_comm, var_new, var_space, var_update_state) @@ -171,10 +171,8 @@ def Assembler_assemble(self, orig, orig_args, tensor, form): def SystemAssembler__init__(self, orig, orig_args, A_form, b_form, bcs=None): if bcs is None: bcs = () - elif isinstance(bcs, backend_DirichletBC): - bcs = (bcs,) else: - bcs = tuple(bcs) + bcs = packed(bcs) orig_args() @@ -834,10 +832,8 @@ def NonlinearVariationalProblem__init__( form_compiler_parameters=None): if bcs is None: bcs = () - elif isinstance(bcs, backend_DirichletBC): - bcs = (bcs,) else: - bcs = tuple(bcs) + bcs = packed(bcs) orig_args() @@ -907,10 +903,8 @@ def _project(v, V=None, bcs=None, mesh=None, function=None, if bcs is None: bcs = () - elif isinstance(bcs, backend_DirichletBC): - bcs = (bcs,) else: - bcs = tuple(bcs) + bcs = packed(bcs) solver_parameters_ = {"linear_solver": solver_type, "preconditioner": preconditioner_type} diff --git a/tlm_adjoint/fenics/caches.py b/tlm_adjoint/fenics/caches.py index ed8c9411..c632d6f3 100644 --- a/tlm_adjoint/fenics/caches.py +++ b/tlm_adjoint/fenics/caches.py @@ -3,10 +3,9 @@ """ from .backend import ( - Parameters, TrialFunction, backend_DirichletBC, backend_Function, - backend_LocalSolver) + Parameters, TrialFunction, backend_Function, backend_LocalSolver) from ..interface import ( - is_var, var_caches, var_id, var_is_cached, var_is_replacement, + is_var, packed, var_caches, var_id, var_is_cached, var_is_replacement, var_lock_state, var_replacement, var_space, var_state) from ..caches import Cache @@ -19,7 +18,7 @@ from .variables import ReplacementFunction from collections import defaultdict -from collections.abc import Sequence +from collections.abc import Mapping, Sequence import itertools try: import ufl_legacy as ufl @@ -291,11 +290,10 @@ def parameters_key(parameters): key = [] for name in sorted(parameters.keys()): sub_parameters = parameters[name] - if isinstance(sub_parameters, (Parameters, dict)): + if isinstance(sub_parameters, (Parameters, Mapping)): key.append((name, parameters_key(sub_parameters))) - elif isinstance(sub_parameters, Sequence) \ - and not isinstance(sub_parameters, str): - key.append((name, tuple(sub_parameters))) + elif isinstance(sub_parameters, Sequence): + key.append((name, tuple(packed(sub_parameters)))) else: key.append((name, sub_parameters)) return tuple(key) @@ -333,8 +331,8 @@ def assemble(self, form, *, if bcs is None: bcs = () - elif isinstance(bcs, backend_DirichletBC): - bcs = (bcs,) + else: + bcs = packed(bcs) if form_compiler_parameters is None: form_compiler_parameters = {} if linear_solver_parameters is None: @@ -410,8 +408,8 @@ def linear_solver(self, form, *, if bcs is None: bcs = () - elif isinstance(bcs, backend_DirichletBC): - bcs = (bcs,) + else: + bcs = packed(bcs) if form_compiler_parameters is None: form_compiler_parameters = {} if linear_solver_parameters is None: diff --git a/tlm_adjoint/fenics/interpolation.py b/tlm_adjoint/fenics/interpolation.py index 7a2ea551..a2fb0f0d 100644 --- a/tlm_adjoint/fenics/interpolation.py +++ b/tlm_adjoint/fenics/interpolation.py @@ -5,7 +5,7 @@ Cell, Mesh, MeshEditor, Point, UserExpression, backend_Constant, backend_Function, backend_ScalarType, parameters) from ..interface import ( - check_space_type, comm_dup_cached, is_var, space_comm, space_eq, + check_space_type, comm_dup_cached, packed, space_comm, space_eq, var_assign, var_comm, var_get_values, var_id, var_inner, var_is_scalar, var_local_size, var_new, var_new_conjugate_dual, var_replacement, var_scalar_value, var_set_values) @@ -525,8 +525,7 @@ class PointInterpolation(Equation): def __init__(self, X, y, X_coords=None, *, P=None, tolerance=0.0): - if is_var(X): - X = (X,) + X = packed(X) for x in X: check_space_type(x, "primal") if not var_is_scalar(x): @@ -559,8 +558,6 @@ def __init__(self, X, y, X_coords=None, *, self._P_T = P.T def forward_solve(self, X, deps=None): - if is_var(X): - X = (X,) y = (self.dependencies() if deps is None else deps)[-1] check_space_type(y, "primal") @@ -577,8 +574,6 @@ def forward_solve(self, X, deps=None): var_assign(x, x_v[i]) def adjoint_derivative_action(self, nl_deps, dep_index, adj_X): - if is_var(adj_X): - adj_X = (adj_X,) if dep_index != len(self.X()): raise ValueError("Unexpected dep_index") diff --git a/tlm_adjoint/fenics/solve.py b/tlm_adjoint/fenics/solve.py index 86482a2e..564dcabe 100644 --- a/tlm_adjoint/fenics/solve.py +++ b/tlm_adjoint/fenics/solve.py @@ -5,8 +5,8 @@ Parameters, adjoint, backend_DirichletBC, backend_LocalSolver, backend_solve as solve, parameters) from ..interface import ( - check_space_type, var_axpy, var_copy, var_id, var_new_conjugate_dual, - var_replacement, var_update_caches, var_zero) + check_space_type, packed, var_axpy, var_copy, var_id, + var_new_conjugate_dual, var_replacement, var_update_caches, var_zero) from ..caches import CacheRef from ..equation import ZeroAssignment @@ -113,10 +113,8 @@ def __init__(self, eq, x, bcs=None, *, match_quadrature=None): if bcs is None: bcs = () - elif isinstance(bcs, backend_DirichletBC): - bcs = (bcs,) else: - bcs = tuple(bcs) + bcs = packed(bcs) if form_compiler_parameters is None: form_compiler_parameters = {} if solver_parameters is None: diff --git a/tlm_adjoint/firedrake/backend_interface.py b/tlm_adjoint/firedrake/backend_interface.py index 9ea21695..4934db33 100644 --- a/tlm_adjoint/firedrake/backend_interface.py +++ b/tlm_adjoint/firedrake/backend_interface.py @@ -1,9 +1,8 @@ from .backend import ( - LinearSolver, Tensor, backend_Cofunction, backend_DirichletBC, - backend_Function, backend_Matrix, backend_assemble, backend_solve, - extract_args) + LinearSolver, Tensor, backend_Cofunction, backend_Function, backend_Matrix, + backend_assemble, backend_solve, extract_args) from ..interface import ( - check_space_type, check_space_types_conjugate_dual, + check_space_type, check_space_types_conjugate_dual, packed, register_garbage_cleanup, space_eq, space_new) from ..patch import patch_method @@ -24,8 +23,8 @@ def _assemble(form, tensor=None, bcs=None, *, form_compiler_parameters=None, mat_type=None): if bcs is None: bcs = () - elif isinstance(bcs, backend_DirichletBC): - bcs = (bcs,) + else: + bcs = packed(bcs) if form_compiler_parameters is None: form_compiler_parameters = {} @@ -43,7 +42,7 @@ def _assemble(form, tensor=None, bcs=None, *, bc.apply(b.riesz_representation("l2")) else: b = backend_assemble( - form, tensor=tensor, bcs=bcs, + form, tensor=tensor, bcs=tuple(bcs), form_compiler_parameters=form_compiler_parameters, mat_type=mat_type) @@ -54,8 +53,8 @@ def _assemble_system(A_form, b_form=None, bcs=None, *, form_compiler_parameters=None, mat_type=None): if bcs is None: bcs = () - elif isinstance(bcs, backend_DirichletBC): - bcs = (bcs,) + else: + bcs = packed(bcs) if form_compiler_parameters is None: form_compiler_parameters = {} @@ -108,8 +107,8 @@ def assemble_matrix(form, bcs=None, *, form_compiler_parameters=None, mat_type=None): if bcs is None: bcs = () - elif isinstance(bcs, backend_DirichletBC): - bcs = (bcs,) + else: + bcs = packed(bcs) if form_compiler_parameters is None: form_compiler_parameters = {} @@ -173,8 +172,8 @@ def assemble_linear_solver(A_form, b_form=None, bcs=None, *, linear_solver_parameters=None): if bcs is None: bcs = () - elif isinstance(bcs, backend_DirichletBC): - bcs = (bcs,) + else: + bcs = packed(bcs) if form_compiler_parameters is None: form_compiler_parameters = {} if linear_solver_parameters is None: @@ -236,8 +235,8 @@ def solve(*args, **kwargs): check_space_type(x, "primal") if bcs is None: bcs = () - elif isinstance(bcs, backend_DirichletBC): - bcs = (bcs,) + else: + bcs = packed(bcs) if form_compiler_parameters is None: form_compiler_parameters = {} if solver_parameters is None: @@ -271,7 +270,7 @@ def solve(*args, **kwargs): "solver parameter") near_nullspace = tlm_adjoint_parameters["near_nullspace"] - return backend_solve(eq, x, bcs, J=J, Jp=Jp, M=M, + return backend_solve(eq, x, tuple(bcs), J=J, Jp=Jp, M=M, form_compiler_parameters=form_compiler_parameters, solver_parameters=solver_parameters, nullspace=nullspace, diff --git a/tlm_adjoint/firedrake/block_system.py b/tlm_adjoint/firedrake/block_system.py index 863c7238..4900404c 100644 --- a/tlm_adjoint/firedrake/block_system.py +++ b/tlm_adjoint/firedrake/block_system.py @@ -2,7 +2,7 @@ """ from .backend import TestFunction, backend_assemble, backend_DirichletBC -from ..interface import space_eq, var_axpy, var_inner, var_new +from ..interface import packed, space_eq, var_axpy, var_inner, var_new from ..block_system import ( BlockMatrix as _BlockMatrix, BlockNullspace, LinearSolver as _LinearSolver, @@ -11,7 +11,6 @@ from .backend_interface import assemble, matrix_multiply from .variables import Constant, Function -from collections.abc import Sequence import ufl __all__ = \ @@ -36,8 +35,7 @@ def apply_bcs(u, bcs): - if not isinstance(bcs, Sequence): - bcs = (bcs,) + bcs = packed(bcs) if len(bcs) > 0 and not isinstance(u.function_space(), type(bcs[0].function_space())): # noqa: E501 u_bc = u.riesz_representation("l2") else: @@ -158,11 +156,7 @@ class DirichletBCNullspace(Nullspace): """ def __init__(self, bcs, *, alpha=1.0): - if isinstance(bcs, Sequence): - bcs = tuple(bcs) - else: - bcs = (bcs,) - + bcs = packed(bcs) space = bcs[0].function_space() for bc in bcs: if not space_eq(bc.function_space(), space): diff --git a/tlm_adjoint/firedrake/caches.py b/tlm_adjoint/firedrake/caches.py index 152851f3..890be605 100644 --- a/tlm_adjoint/firedrake/caches.py +++ b/tlm_adjoint/firedrake/caches.py @@ -2,11 +2,9 @@ caching. """ -from .backend import ( - Parameters, TrialFunction, backend_DirichletBC, backend_Function, - complex_mode) +from .backend import Parameters, TrialFunction, backend_Function, complex_mode from ..interface import ( - is_var, var_caches, var_id, var_is_cached, var_is_replacement, + is_var, packed, var_caches, var_id, var_is_cached, var_is_replacement, var_lock_state, var_replacement, var_space, var_state) from ..caches import Cache @@ -19,7 +17,7 @@ from .variables import ReplacementFunction from collections import defaultdict -from collections.abc import Sequence +from collections.abc import Mapping, Sequence import itertools import ufl @@ -345,11 +343,10 @@ def parameters_key(parameters): key = [] for name in sorted(parameters.keys()): sub_parameters = parameters[name] - if isinstance(sub_parameters, (Parameters, dict)): + if isinstance(sub_parameters, (Parameters, Mapping)): key.append((name, parameters_key(sub_parameters))) - elif isinstance(sub_parameters, Sequence) \ - and not isinstance(sub_parameters, str): - key.append((name, tuple(sub_parameters))) + elif isinstance(sub_parameters, Sequence): + key.append((name, tuple(packed(sub_parameters)))) else: key.append((name, sub_parameters)) return tuple(key) @@ -387,8 +384,8 @@ def assemble(self, form, *, if bcs is None: bcs = () - elif isinstance(bcs, backend_DirichletBC): - bcs = (bcs,) + else: + bcs = packed(bcs) if form_compiler_parameters is None: form_compiler_parameters = {} if linear_solver_parameters is None: @@ -464,8 +461,8 @@ def linear_solver(self, form, *, if bcs is None: bcs = () - elif isinstance(bcs, backend_DirichletBC): - bcs = (bcs,) + else: + bcs = packed(bcs) if form_compiler_parameters is None: form_compiler_parameters = {} if linear_solver_parameters is None: diff --git a/tlm_adjoint/firedrake/interpolation.py b/tlm_adjoint/firedrake/interpolation.py index f8896581..7f4dc5c7 100644 --- a/tlm_adjoint/firedrake/interpolation.py +++ b/tlm_adjoint/firedrake/interpolation.py @@ -5,7 +5,7 @@ FunctionSpace, Interpolator, TestFunction, VertexOnlyMesh, backend_Cofunction, backend_Constant, backend_Function) from ..interface import ( - check_space_type, comm_dup_cached, is_var, space_new, var_assign, var_comm, + check_space_type, comm_dup_cached, packed, space_new, var_assign, var_comm, var_copy, var_id, var_inner, var_is_scalar, var_new_conjugate_dual, var_replacement, var_scalar_value) @@ -176,9 +176,7 @@ class PointInterpolation(Equation): def __init__(self, X, y, X_coords=None, *, tolerance=None, _interp=None): - if is_var(X): - X = (X,) - + X = packed(X) for x in X: check_space_type(x, "primal") if not var_is_scalar(x): @@ -211,8 +209,6 @@ def __init__(self, X, y, X_coords=None, *, tolerance=None, self._interp = interp def forward_solve(self, X, deps=None): - if is_var(X): - X = (X,) y = (self.dependencies() if deps is None else deps)[-1] Xm = space_new(self._interp.V) @@ -225,8 +221,6 @@ def forward_solve(self, X, deps=None): X[index].assign(x_val) def adjoint_derivative_action(self, nl_deps, dep_index, adj_X): - if is_var(adj_X): - adj_X = (adj_X,) if dep_index != len(self.X()): raise ValueError("Unexpected dep_index") diff --git a/tlm_adjoint/firedrake/solve.py b/tlm_adjoint/firedrake/solve.py index 2448a59a..313e7587 100644 --- a/tlm_adjoint/firedrake/solve.py +++ b/tlm_adjoint/firedrake/solve.py @@ -1,9 +1,9 @@ """Finite element variational problem solution operations with Firedrake. """ -from .backend import adjoint, backend_DirichletBC, parameters +from .backend import adjoint, parameters from ..interface import ( - check_space_type, is_var, var_axpy, var_copy, var_id, var_new, + check_space_type, is_var, packed, var_axpy, var_copy, var_id, var_new, var_new_conjugate_dual, var_replacement, var_update_caches, var_zero) from ..caches import CacheRef @@ -43,8 +43,8 @@ class BCIndex(int): def unpack_bcs(bcs, *, deps=None): if bcs is None: bcs = () - elif isinstance(bcs, backend_DirichletBC): - bcs = (bcs,) + else: + bcs = packed(bcs) if deps is None: deps = [] dep_ids = set(map(var_id, deps)) @@ -146,10 +146,8 @@ def __init__(self, eq, x, bcs=None, *, match_quadrature=None): if bcs is None: bcs = () - elif isinstance(bcs, backend_DirichletBC): - bcs = (bcs,) else: - bcs = tuple(bcs) + bcs = packed(bcs) if form_compiler_parameters is None: form_compiler_parameters = {} if solver_parameters is None: diff --git a/tlm_adjoint/fixed_point.py b/tlm_adjoint/fixed_point.py index 1fdba521..66f230b2 100644 --- a/tlm_adjoint/fixed_point.py +++ b/tlm_adjoint/fixed_point.py @@ -1,5 +1,5 @@ from .interface import ( - is_var, no_space_type_checking, var_assign, var_axpy, var_copy, var_id, + no_space_type_checking, packed, var_assign, var_axpy, var_copy, var_id, var_inner, var_zero) from .adjoint import AdjointModelRHS @@ -56,8 +56,7 @@ def __init__(self, eqs, *, norm_sqs=None, adj_norm_sqs=None): if len(eqs) != len(norm_sqs): raise ValueError("Invalid squared norm callable(s)") for i, (eq, X_norm_sq) in enumerate(zip(eqs, norm_sqs)): - if callable(X_norm_sq): - X_norm_sq = (X_norm_sq,) + X_norm_sq = packed(X_norm_sq) if len(eq.X()) != len(X_norm_sq): raise ValueError("Invalid squared norm callable(s)") norm_sqs[i] = tuple(X_norm_sq) @@ -66,8 +65,7 @@ def __init__(self, eqs, *, norm_sqs=None, adj_norm_sqs=None): if len(eqs) != len(adj_norm_sqs): raise ValueError("Invalid squared norm callable(s)") for i, (eq, X_norm_sq) in enumerate(zip(eqs, adj_norm_sqs)): - if callable(X_norm_sq): - X_norm_sq = (X_norm_sq,) + X_norm_sq = packed(X_norm_sq) if len(eq.X()) != len(X_norm_sq): raise ValueError("Invalid squared norm callable(s)") adj_norm_sqs[i] = tuple(X_norm_sq) @@ -302,9 +300,6 @@ def drop_references(self): self._eqs = tuple(map(WeakAlias, self._eqs)) def forward_solve(self, X, deps=None): - if is_var(X): - X = (X,) - # Based on KrylovSolver parameters in FEniCS 2017.2.0 absolute_tolerance = self._solver_parameters["absolute_tolerance"] relative_tolerance = self._solver_parameters["relative_tolerance"] @@ -370,12 +365,8 @@ def forward_solve(self, X, deps=None): var_assign(x_0, x) def adjoint_jacobian_solve(self, adj_X, nl_deps, B): - if is_var(B): - B = (B,) if adj_X is None: adj_X = list(self.new_adj_X()) - elif is_var(adj_X): - adj_X = [adj_X] else: adj_X = list(adj_X) @@ -404,7 +395,7 @@ def adjoint_jacobian_solve(self, adj_X, nl_deps, B): if nonzero_initial_guess: for i, eq in enumerate(self._eqs): eq.subtract_adjoint_derivative_actions( - eq_adj_X[i][0] if len(eq_adj_X[i]) == 1 else eq_adj_X[i], + eq._unpack(eq_adj_X[i]), eq_nl_deps[i], dep_Bs[i]) else: for adj_x in adj_X: @@ -472,9 +463,6 @@ def adjoint_jacobian_solve(self, adj_X, nl_deps, B): return adj_X def subtract_adjoint_derivative_actions(self, adj_X, nl_deps, dep_Bs): - if is_var(adj_X): - adj_X = (adj_X,) - eq_deps = self.dependencies() eq_dep_Bs = tuple({} for _ in self._eqs) for dep_index, B in dep_Bs.items(): @@ -487,7 +475,7 @@ def subtract_adjoint_derivative_actions(self, adj_X, nl_deps, dep_Bs): eq_adj_X = tuple(adj_X[j] for j in self._eq_X_indices[i]) eq_nl_deps = tuple(nl_deps[j] for j in self._eq_nl_dep_indices[i]) eq.subtract_adjoint_derivative_actions( - eq_adj_X[0] if len(eq_adj_X) == 1 else eq_adj_X, + eq._unpack(eq_adj_X), eq_nl_deps, eq_dep_Bs[i]) def tangent_linear(self, tlm_map): diff --git a/tlm_adjoint/hessian.py b/tlm_adjoint/hessian.py index 039038d7..38fbe7da 100644 --- a/tlm_adjoint/hessian.py +++ b/tlm_adjoint/hessian.py @@ -1,5 +1,5 @@ from .interface import ( - check_space_types_conjugate_dual, is_var, var_axpy, var_copy, + Packed, check_space_types_conjugate_dual, packed, var_axpy, var_copy, var_copy_conjugate, var_is_cached, var_is_static, var_name, var_new, var_scalar_value) @@ -115,10 +115,10 @@ def __init__(self, forward, *, manager=None): @local_caches @restore_manager def compute_gradient(self, M, M0=None): - if is_var(M): - J, (dJ,) = self.compute_gradient( - (M,), M0=None if M0 is None else (M0,)) - return J, dJ + M_packed = Packed(M) + M = tuple(M_packed) + if M0 is not None: + M0 = packed(M0) set_manager(self._manager) reset_manager() @@ -140,15 +140,16 @@ def compute_gradient(self, M, M0=None): dJ = compute_gradient(J, M) reset_manager() - return J_val, dJ + return J_val, M_packed.unpack(dJ) @local_caches @restore_manager def action(self, M, dM, M0=None): - if is_var(M): - J_val, dJ_val, (ddJ,) = self.action( - (M,), (dM,), M0=None if M0 is None else (M0,)) - return J_val, dJ_val, ddJ + M_packed = Packed(M) + M = tuple(M_packed) + dM = packed(dM) + if M0 is not None: + M0 = packed(M0) set_manager(self._manager) reset_manager() @@ -178,7 +179,7 @@ def action(self, M, dM, M0=None): ddJ = compute_gradient(dJ, M) reset_manager() - return J_val, dJ_val, ddJ + return J_val, dJ_val, M_packed.unpack(ddJ) class GaussNewton(ABC): @@ -235,10 +236,11 @@ def action(self, M, dM, M0=None): variables depending on the type of `M`. """ - if is_var(M): - ddJ, = self.action( - (M,), (dM,), M0=None if M0 is None else (M0,)) - return ddJ + M_packed = Packed(M) + M = tuple(M_packed) + dM = packed(dM) + if M0 is not None: + M0 = packed(M0) manager, M, dM, X = self._setup_manager(M, dM, M0=M0) set_manager(manager) @@ -247,8 +249,7 @@ def action(self, M, dM, M0=None): tau_X = tuple(var_tlm(x, (M, dM)) for x in X) # conj[ R^{-1} J dM ] R_inv_tau_X = self._R_inv_action(*map(var_copy, tau_X)) - if is_var(R_inv_tau_X): - R_inv_tau_X = (R_inv_tau_X,) + R_inv_tau_X = packed(R_inv_tau_X) assert len(tau_X) == len(R_inv_tau_X) for tau_x, R_inv_tau_x in zip(tau_X, R_inv_tau_X): check_space_types_conjugate_dual(tau_x, R_inv_tau_x) @@ -271,8 +272,7 @@ def action(self, M, dM, M0=None): # Prior term: conj[ B^{-1} dM ] if self._B_inv_action is not None: B_inv_dM = self._B_inv_action(*map(var_copy, dM)) - if is_var(B_inv_dM): - B_inv_dM = (B_inv_dM,) + B_inv_dM = packed(B_inv_dM) assert len(dM) == len(B_inv_dM) for dm, B_inv_dm in zip(dM, B_inv_dM): check_space_types_conjugate_dual(dm, B_inv_dm) @@ -281,7 +281,7 @@ def action(self, M, dM, M0=None): var_axpy(ddJ[i], 1.0, B_inv_dm) reset_manager() - return ddJ + return M_packed.unpack(ddJ) def action_fn(self, m, m0=None): """Return a callable which can be used to compute Hessian actions using @@ -349,8 +349,7 @@ def _setup_manager(self, M, dM, M0=None): configure_tlm((M, dM), annotate=False) start_manager() X = self._forward(*M) - if is_var(X): - X = (X,) + X = packed(X) stop_manager() return self._manager, M, dM, X diff --git a/tlm_adjoint/hessian_system.py b/tlm_adjoint/hessian_system.py index be017a43..09ceedf7 100644 --- a/tlm_adjoint/hessian_system.py +++ b/tlm_adjoint/hessian_system.py @@ -1,13 +1,11 @@ from .interface import ( - var_axpy_conjugate, var_copy_conjugate, var_increment_state_lock, + packed, var_axpy_conjugate, var_copy_conjugate, var_increment_state_lock, var_space) from .block_system import ( BlockNullspace, LinearSolver, Matrix, NoneNullspace, TypedSpace) from .manager import manager_disabled -from collections.abc import Sequence - __all__ = \ [ "HessianMatrix", @@ -32,10 +30,7 @@ class HessianMatrix(Matrix): """ def __init__(self, H, M): - if isinstance(M, Sequence): - M = tuple(M) - else: - M = (M,) + M = packed(M) arg_space = tuple(TypedSpace(var_space(m)) for m in M) action_space = tuple(TypedSpace(var_space(m), space_type="dual") for m in M) # noqa: E501 @@ -94,8 +89,5 @@ def solve(self, u, b, **kwargs): :meth:`tlm_adjoint.block_system.LinearSolver.solve` method. """ - if isinstance(b, Sequence): - b = tuple(map(var_copy_conjugate, b)) - else: - b = var_copy_conjugate(b) - return super().solve(u, b, **kwargs) + b_conj = tuple(map(var_copy_conjugate, packed(b))) + super().solve(u, b_conj, **kwargs) diff --git a/tlm_adjoint/interface.py b/tlm_adjoint/interface.py index 0a6b46a0..af6d00cf 100644 --- a/tlm_adjoint/interface.py +++ b/tlm_adjoint/interface.py @@ -1207,8 +1207,7 @@ def var_update_caches(*X, value=None): var_check_state_lock(x) var_caches(x).update(x) else: - if is_var(value): - value = (value,) + value = packed(value) var_update_caches(*value) assert len(X) == len(value) for x, x_value in zip(X, value): @@ -1598,6 +1597,65 @@ def add_replacement_interface(replacement, x): "caches": var_caches(x)}) +class Packed(Sequence): + """A convenience class for converting objects to and from an immutable + :class:`Sequence`. Functionality based on the pyadjoint `Enlist` class + (see e.g. pyadjoint master branch revision + 908b6364e402a6776f2a378297beecaf2bebfb87). + """ + + def __init__(self, obj): + if isinstance(obj, Sequence) \ + and not is_space(obj) \ + and not is_var(obj) \ + and not isinstance(obj, str): + t = tuple(obj) + is_packed = False + else: + t = (obj,) + is_packed = True + + self._obj = obj + self._t = t + self._is_packed = is_packed + + def __eq__(self, other): + other = Packed(other) + return (tuple(self) == tuple(other) + and self.is_packed == other.is_packed) + + def __ne__(self, other): + return not self == other + + def __hash__(self): + return hash((tuple(self), self.is_packed)) + + def __len__(self): + return len(self._t) + + def __getitem__(self, key): + return self._t[key] + + @property + def is_packed(self): + return self._is_packed + + def unpack(self, obj): + obj_packed = Packed(obj) + if len(obj_packed) != len(self): + raise ValueError("Invalid length") + if self.is_packed: + obj, = obj_packed + return obj + + def mapped(self, fn): + return Packed(self.unpack(tuple(map(fn, self)))) + + +def packed(obj): + return tuple(Packed(obj)) + + @functools.singledispatch def subtract_adjoint_derivative_action(x, y): """Subtract an adjoint right-hand-side contribution defined by `y` from @@ -1617,10 +1675,8 @@ def subtract_adjoint_derivative_action(x, y): def register_subtract_adjoint_derivative_action(x_cls, y_cls, fn, *, replace=False): - if not isinstance(x_cls, Sequence): - x_cls = (x_cls,) - if not isinstance(y_cls, Sequence): - y_cls = (y_cls,) + x_cls = packed(x_cls) + y_cls = packed(y_cls) for x_cls, y_cls in itertools.product(x_cls, y_cls): if x_cls not in subtract_adjoint_derivative_action.registry: @functools.singledispatch @@ -1671,8 +1727,7 @@ def functional_term_eq(x, term): def register_functional_term_eq(term_cls, fn, *, replace=False): - if not isinstance(term_cls, Sequence): - term_cls = (term_cls,) + term_cls = packed(term_cls) for term_cls in term_cls: if term_cls in _functional_term_eq.registry and not replace: raise RuntimeError("Case already registered") diff --git a/tlm_adjoint/jax.py b/tlm_adjoint/jax.py index 865f6c58..85a142c2 100644 --- a/tlm_adjoint/jax.py +++ b/tlm_adjoint/jax.py @@ -1,6 +1,6 @@ from .interface import ( DEFAULT_COMM, SpaceInterface, VariableInterface, add_interface, - add_replacement_interface, comm_dup_cached, is_var, new_space_id, + add_replacement_interface, comm_dup_cached, packed, new_space_id, new_var_id, register_subtract_adjoint_derivative_action, subtract_adjoint_derivative_action_base, var_assign, var_axpy, var_comm, var_dtype, var_id, var_is_scalar, var_local_size, var_set_values, @@ -12,7 +12,6 @@ from .manager import ( annotation_enabled, manager as _manager, paused_manager, tlm_enabled) -from collections.abc import Sequence import contextlib import functools try: @@ -594,10 +593,8 @@ def fn(y0, y1, ...): """ def __init__(self, X, Y, fn, *, with_tlm=True, _forward_eq=None): - if is_var(X): - X = (X,) - if is_var(Y): - Y = (Y,) + X = packed(X) + Y = packed(Y) if len(X) != len(set(X)): raise ValueError("Duplicate solution") if len(Y) != len(set(Y)): @@ -613,8 +610,7 @@ def wrapped_fn(*args): if len(args) != n_Y: raise ValueError("Unexpected number of inputs") X_val = fn(*args) - if not isinstance(X_val, Sequence): - X_val = X_val, + X_val = packed(X_val) if len(X_val) != n_X: raise ValueError("Unexpected number of outputs") @@ -688,8 +684,7 @@ def solve(self, *, annotate=None, tlm=None): eq._annotate = True def forward_solve(self, X, deps=None): - if is_var(X): - X = (X,) + X = packed(X) if deps is None: deps = self.dependencies() Y = deps[len(X):] @@ -708,8 +703,7 @@ def adjoint_jacobian_solve(self, adj_X, nl_deps, B): return B def subtract_adjoint_derivative_actions(self, adj_X, nl_deps, dep_Bs): - if is_var(adj_X): - adj_X = (adj_X,) + adj_X = packed(adj_X) _, vjp = self._jax_reverse(*nl_deps) dF = vjp(tuple(adj_x.vector.conjugate() for adj_x in adj_X)) N_X = len(self.X()) diff --git a/tlm_adjoint/linear_equation.py b/tlm_adjoint/linear_equation.py index 42524930..c0ab2d60 100644 --- a/tlm_adjoint/linear_equation.py +++ b/tlm_adjoint/linear_equation.py @@ -1,6 +1,6 @@ from .interface import ( - conjugate_dual_space_type, is_var, var_id, var_new, var_new_conjugate_dual, - var_replacement, var_zero) + Packed, conjugate_dual_space_type, packed, var_id, var_new, + var_new_conjugate_dual, var_replacement, var_zero) from .equation import Equation, Referrer, ZeroAssignment @@ -42,10 +42,9 @@ class LinearEquation(Equation): """ def __init__(self, X, B, *, A=None, adj_type=None): - if is_var(X): - X = (X,) - if isinstance(B, RHS): - B = (B,) + X_packed = Packed(X) + X = tuple(X_packed) + B = packed(B) if adj_type is None: if A is None: adj_type = "conjugate_dual" @@ -119,7 +118,7 @@ def __init__(self, X, B, *, A=None, adj_type=None): del x_ids, dep_ids, nl_dep_ids super().__init__( - X, deps, nl_deps=nl_deps, + X_packed.unpack(X), deps, nl_deps=nl_deps, ic=A is not None and A.has_initial_condition(), adj_ic=A is not None and A.adjoint_has_initial_condition(), adj_type=adj_type) @@ -146,8 +145,7 @@ def drop_references(self): self._A = self._A._weak_alias def forward_solve(self, X, deps=None): - if is_var(X): - X = (X,) + X = packed(X) if deps is None: deps = self.dependencies() @@ -160,13 +158,13 @@ def forward_solve(self, X, deps=None): for m, x in enumerate(X)) for i, b in enumerate(self._B): - b.add_forward(B[0] if len(B) == 1 else B, + b.add_forward(self._unpack(B), [deps[j] for j in self._b_dep_indices[i]]) if self._A is not None: - self._A.forward_solve(X[0] if len(X) == 1 else X, + self._A.forward_solve(self._unpack(X), [deps[j] for j in self._A_dep_indices], - B[0] if len(B) == 1 else B) + self._unpack(B)) def adjoint_jacobian_solve(self, adj_X, nl_deps, B): if self._A is None: @@ -176,8 +174,7 @@ def adjoint_jacobian_solve(self, adj_X, nl_deps, B): adj_X, [nl_deps[j] for j in self._A_nl_dep_indices], B) def adjoint_derivative_action(self, nl_deps, dep_index, adj_X): - if is_var(adj_X): - adj_X = (adj_X,) + adj_X = packed(adj_X) eq_deps = self.dependencies() if dep_index < len(self.X()) or dep_index >= len(eq_deps): raise ValueError("Unexpected dep_index") @@ -195,7 +192,7 @@ def adjoint_derivative_action(self, nl_deps, dep_index, adj_X): b_nl_deps = tuple(nl_deps[j] for j in self._b_nl_dep_indices[i]) b.subtract_adjoint_derivative_action( b_nl_deps, b_dep_index, - adj_X[0] if len(adj_X) == 1 else adj_X, + self._unpack(adj_X), F) if self._A is not None and dep_id in self._A_nl_dep_ids: @@ -204,8 +201,8 @@ def adjoint_derivative_action(self, nl_deps, dep_index, adj_X): X = tuple(nl_deps[j] for j in self._A_x_indices) self._A.adjoint_derivative_action( A_nl_deps, A_nl_dep_index, - X[0] if len(X) == 1 else X, - adj_X[0] if len(adj_X) == 1 else adj_X, + self._unpack(X), + self._unpack(adj_X), F, method="add") return F @@ -216,26 +213,24 @@ def tangent_linear(self, tlm_map): if self._A is None: tlm_B = [] else: - tlm_B = self._A.tangent_linear_rhs(tlm_map, - X[0] if len(X) == 1 else X) + tlm_B = self._A.tangent_linear_rhs(tlm_map, self._unpack(X)) if tlm_B is None: tlm_B = [] - elif isinstance(tlm_B, RHS): - tlm_B = [tlm_B] + else: + tlm_B = list(packed(tlm_B)) for b in self._B: tlm_b = b.tangent_linear_rhs(tlm_map) if tlm_b is None: pass - elif isinstance(tlm_b, RHS): - tlm_B.append(tlm_b) else: - tlm_B.extend(tlm_b) + tlm_B.extend(packed(tlm_b)) if len(tlm_B) == 0: return ZeroAssignment([tlm_map[x] for x in self.X()]) else: - return LinearEquation([tlm_map[x] for x in self.X()], tlm_B, - A=self._A, adj_type=self.adj_X_type()) + return LinearEquation( + self._unpack([tlm_map[x] for x in self.X()]), tlm_B, + A=self._A, adj_type=self.adj_X_type()) class Matrix(Referrer): diff --git a/tlm_adjoint/markers.py b/tlm_adjoint/markers.py index 0e583ad8..ca230df5 100644 --- a/tlm_adjoint/markers.py +++ b/tlm_adjoint/markers.py @@ -1,4 +1,4 @@ -from .interface import is_var, var_new +from .interface import Packed, var_new from .equation import Equation @@ -28,10 +28,11 @@ class ControlsMarker(Equation): """ def __init__(self, M): - if is_var(M): - M = (M,) + M_packed = Packed(M) + M = tuple(M_packed) super(Equation, self).__init__() + self._packed = M_packed.mapped(lambda m: None) self._X = tuple(M) self._deps = tuple(M) self._nl_deps = () @@ -63,7 +64,7 @@ class FunctionalMarker(Equation): def __init__(self, J): # Extra variable allocation could be avoided J_ = var_new(J) - super().__init__([J_], [J_, J], nl_deps=[], ic=False, adj_ic=False) + super().__init__(J_, [J_, J], nl_deps=[], ic=False, adj_ic=False) def adjoint_derivative_action(self, nl_deps, dep_index, adj_x): if dep_index != 1: diff --git a/tlm_adjoint/optimization.py b/tlm_adjoint/optimization.py index d52bd104..85064816 100644 --- a/tlm_adjoint/optimization.py +++ b/tlm_adjoint/optimization.py @@ -1,9 +1,10 @@ from .interface import ( - comm_dup_cached, garbage_cleanup, is_var, paused_space_type_checking, - var_axpy, var_comm, var_copy, var_dtype, var_get_values, var_is_cached, - var_is_static, var_linf_norm, var_local_size, var_locked, var_new, - var_scalar_value, var_set_values, var_space, vars_assign, vars_axpy, - vars_copy, vars_inner, vars_new, vars_new_conjugate_dual) + Packed, comm_dup_cached, garbage_cleanup, packed, + paused_space_type_checking, var_axpy, var_comm, var_copy, var_dtype, + var_get_values, var_is_cached, var_is_static, var_linf_norm, + var_local_size, var_locked, var_new, var_scalar_value, var_set_values, + var_space, vars_assign, vars_axpy, vars_copy, vars_inner, vars_new, + vars_new_conjugate_dual) from .caches import clear_caches, local_caches from .hessian import GeneralHessian as Hessian @@ -55,8 +56,7 @@ def comm(self): @restore_manager def objective(self, M, *, force=False): - if is_var(M): - M = (M,) + M = packed(M) if self._M is not None and len(M) != len(self._M): raise ValueError("Invalid control") for m in M: @@ -109,10 +109,8 @@ def objective(self, M, *, @restore_manager def gradient(self, M): - if is_var(M): - dJ, = self.gradient((M,)) - return dJ - + M_packed = Packed(M) + M = tuple(M_packed) set_manager(self._manager) _ = self.objective(M, force=self._manager._cp_schedule.is_exhausted) @@ -121,13 +119,12 @@ def gradient(self, M): for dJ_i in dJ: if not issubclass(var_dtype(dJ_i), np.floating): raise ValueError("Invalid dtype") - return dJ + return M_packed.unpack(dJ) def hessian_action(self, M, dM): - if is_var(M): - ddJ, = self.hessian_action((M,), (dM,)) - return ddJ - + M_packed = Packed(M) + M = tuple(M_packed) + dM = packed(dM) for m in M: if not issubclass(var_dtype(m), np.floating): raise ValueError("Invalid dtype") @@ -141,7 +138,7 @@ def hessian_action(self, M, dM): for ddJ_i in ddJ: if not issubclass(var_dtype(ddJ_i), np.floating): raise ValueError("Invalid dtype") - return ddJ + return M_packed.unpack(ddJ) @contextlib.contextmanager @@ -178,11 +175,8 @@ def minimize_scipy(forward, M0, *, :func:`scipy.optimize.minimize`. """ - if is_var(M0): - (M,), return_value = minimize_scipy(forward, (M0,), - manager=manager, **kwargs) - return M, return_value - + M0_packed = Packed(M0) + M0 = tuple(M0_packed) if manager is None: manager = _manager() manager = manager.new() @@ -295,7 +289,7 @@ def hessp_bcast(x, p): if not return_value.success: raise RuntimeError("Convergence failure") - return M, return_value + return M0_packed.unpack(M), return_value def conjugate_dual_identity_action(*X): @@ -312,8 +306,7 @@ def wrapped_action(M): def M(*X): with var_locked(*X): M_X = M_arg(*X) - if is_var(M_X): - M_X = (M_X,) + M_X = packed(M_X) if len(M_X) != len(X): raise ValueError("Incompatible shape") return vars_copy(M_X) @@ -343,10 +336,8 @@ def append(self, S, Y, S_inner_Y): that used in the line search can be supplied. """ - if is_var(S): - S = (S,) - if is_var(Y): - Y = (Y,) + S = packed(S) + Y = packed(Y) if len(S) != len(Y): raise ValueError("Incompatible shape") for s in S: @@ -394,9 +385,8 @@ def inverse_action(self, X, *, result. """ - if is_var(X): - X = (X,) - X = vars_copy(X) + X_packed = Packed(X) + X = tuple(X_packed.mapped(var_copy)) if H_0_action is None: H_0_action = wrapped_action(conjugate_dual_identity_action) @@ -420,7 +410,7 @@ def inverse_action(self, X, *, beta = rho * vars_inner(R, Y) vars_axpy(R, alpha - beta, S) - return R[0] if len(R) == 1 else R + return X_packed.unpack(R) def line_search(F, Fp, X, minus_P, *, @@ -482,16 +472,14 @@ def F(*X): return F_arg(*X) Fp = wrapped_action(Fp) - if is_var(X): - X = (X,) - if is_var(minus_P): - minus_P = (minus_P,) + X = packed(X) + minus_P = packed(minus_P) if old_F_val is None: old_F_val = F(*X) if old_Fp_val is None: old_Fp_val = Fp(*X) - elif is_var(old_Fp_val): - old_Fp_val = (old_Fp_val,) + else: + old_Fp_val = packed(old_Fp_val) if comm is None: comm = var_comm(X[0]) @@ -732,14 +720,13 @@ def Fp(*X): Fp_calls += 1 with var_locked(*X): Fp_val = Fp_arg(*X) - if is_var(Fp_val): - Fp_val = (Fp_val,) + Fp_val = packed(Fp_val) if len(Fp_val) != len(X): raise ValueError("Incompatible shape") return vars_copy(Fp_val) - if is_var(X0): - X0 = (X0,) + X0_packed = Packed(X0) + X0 = tuple(X0_packed) if converged is None: def converged(it, F_old, F_new, X_new, G_new, S, Y): @@ -750,10 +737,8 @@ def converged(it, F_old, F_new, X_new, G_new, S, Y): @wraps(converged_arg) def converged(it, F_old, F_new, X_new, G_new, S, Y): return converged_arg(it, F_old, F_new, - X_new[0] if len(X_new) == 1 else X_new, - G_new[0] if len(G_new) == 1 else G_new, - S[0] if len(S) == 1 else S, - Y[0] if len(Y) == 1 else Y) + X0_packed.unpack(X_new), X0_packed.unpack(G_new), # noqa: E501 + X0_packed.unpack(S), X0_packed.unpack(Y)) if (H_0_action is None and M_inv_action is None) and M_action is not None: raise TypeError("If M_action is supplied, then H_0_action or " @@ -825,8 +810,7 @@ def M_inv_norm_sq(X): minus_P = hessian_approx.inverse_action( old_Fp_val, H_0_action=H_0_action, theta=theta) - if is_var(minus_P): - minus_P = (minus_P,) + minus_P = packed(minus_P) old_Fp_val_rank0 = -vars_inner(minus_P, old_Fp_val) alpha, new_F_val, new_Fp_val = line_search( F, Fp, X, minus_P, c1=c1, c2=c2, @@ -880,7 +864,7 @@ def M_inv_norm_sq(X): del new_F_val, new_Fp_val, new_Fp_val_rank0 old_Fp_norm_sq = M_inv_norm_sq(old_Fp_val) - return X[0] if len(X) == 1 else X, (it, F_calls, Fp_calls, hessian_approx) + return X0_packed.unpack(X), (it, F_calls, Fp_calls, hessian_approx) @local_caches @@ -900,11 +884,8 @@ def minimize_l_bfgs(forward, M0, *, :func:`.l_bfgs` documentation. """ - if is_var(M0): - (x,), optimization_data = minimize_l_bfgs( - forward, (M0,), - m=m, manager=manager, **kwargs) - return x, optimization_data + M0_packed = Packed(M0) + M0 = tuple(M0_packed) for m0 in M0: if not issubclass(var_dtype(m0), np.floating): @@ -920,9 +901,7 @@ def minimize_l_bfgs(forward, M0, *, lambda *M: J_hat.objective(M), lambda *M: J_hat.gradient(M), M0, m=m, comm=J_hat.comm, **kwargs) - if is_var(X): - X = (X,) - return X, optimization_data + return M0_packed.unpack(X), optimization_data def petsc_tao(J_hat, M, *, solver_parameters=None, @@ -1098,8 +1077,7 @@ class TAOSolver: def __init__(self, forward, M, *, solver_parameters=None, H_0_action=None, M_inv_action=None, manager=None): - if is_var(M): - M = (M,) + M = packed(M) if manager is None: manager = _manager() manager = manager.new() @@ -1133,9 +1111,7 @@ def solve(self, M): :arg M: Defines the solution. """ - if is_var(M): - M = (M,) - + M = packed(M) x = PETScVec(self._vec_interface) x.to_petsc(M) self.tao.solve(x.vec) @@ -1155,11 +1131,9 @@ def minimize_tao(forward, M0, *args, **kwargs): Remaining arguments are passed to the :class:`.TAOSolver` constructor. """ - if is_var(M0): - m, = minimize_tao(forward, (M0,), *args, **kwargs) - return m - + M0_packed = Packed(M0) + M0 = tuple(M0_packed) M = tuple(var_copy(m0, static=var_is_static(m0), cache=var_is_cached(m0)) for m0 in M0) TAOSolver(forward, M, *args, **kwargs).solve(M) - return M + return M0_packed.unpack(M) diff --git a/tlm_adjoint/tangent_linear.py b/tlm_adjoint/tangent_linear.py index 962025f4..3c3ca3f3 100644 --- a/tlm_adjoint/tangent_linear.py +++ b/tlm_adjoint/tangent_linear.py @@ -1,5 +1,5 @@ from .interface import ( - check_space_types, is_var, var_id, var_is_replacement, var_name, + check_space_types, is_var, packed, var_id, var_is_replacement, var_name, var_new_tangent_linear) from .alias import gc_disabled @@ -18,15 +18,8 @@ def tlm_key(M, dM): - if is_var(M): - M = (M,) - else: - M = tuple(M) - if is_var(dM): - dM = (dM,) - else: - dM = tuple(dM) - + M = packed(M) + dM = packed(dM) if any(map(var_is_replacement, M)): raise ValueError("Invalid tangent-linear") if any(map(var_is_replacement, dM)): diff --git a/tlm_adjoint/tlm_adjoint.py b/tlm_adjoint/tlm_adjoint.py index fb3cd1c1..dfbd140f 100644 --- a/tlm_adjoint/tlm_adjoint.py +++ b/tlm_adjoint/tlm_adjoint.py @@ -1,5 +1,5 @@ from .interface import ( - DEFAULT_COMM, comm_dup_cached, garbage_cleanup, is_var, var_assign, + DEFAULT_COMM, Packed, comm_dup_cached, garbage_cleanup, var_assign, var_copy, var_id, var_is_replacement, var_is_scalar, var_name) from .adjoint import AdjointCache, AdjointModelRHS, TransposeComputationalGraph @@ -1138,35 +1138,6 @@ def callback(J_i, n, i, eq, adj_X): with respect to the :math:`j` th control. """ - if is_var(M): - if is_var(Js): - ((dJ,),) = self.compute_gradient( - (Js,), (M,), callback=callback, - prune_forward=prune_forward, prune_adjoint=prune_adjoint, - prune_replay=prune_replay, - cache_adjoint_degree=cache_adjoint_degree, - store_adjoint=store_adjoint, - adj_ics=None if adj_ics is None else (adj_ics,)) - return dJ - else: - dJs = self.compute_gradient( - Js, (M,), callback=callback, - prune_forward=prune_forward, prune_adjoint=prune_adjoint, - prune_replay=prune_replay, - cache_adjoint_degree=cache_adjoint_degree, - store_adjoint=store_adjoint, - adj_ics=adj_ics) - return tuple(dJ for dJ, in dJs) - elif is_var(Js): - dJ, = self.compute_gradient( - (Js,), M, callback=callback, - prune_forward=prune_forward, prune_adjoint=prune_adjoint, - prune_replay=prune_replay, - cache_adjoint_degree=cache_adjoint_degree, - store_adjoint=store_adjoint, - adj_ics=None if adj_ics is None else (adj_ics,)) - return dJ - set_manager(self) self.finalize() @@ -1174,12 +1145,14 @@ def callback(J_i, n, i, eq, adj_X): raise RuntimeError("Invalid checkpointing state") # Functionals - Js = tuple(Js) + Js_packed = Packed(Js) + Js = tuple(Js_packed) if not all(map(var_is_scalar, Js)): raise ValueError("Functional must be a scalar variable") # Controls - M = tuple(M) + M_packed = Packed(M) + M = tuple(M_packed) # Derivatives dJ = [None for J in Js] @@ -1312,17 +1285,10 @@ def callback(J_i, n, i, eq, adj_X): self._adj_cache.cache(J_i, n, i, adj_X, copy=True, store=store_adjoint) + # Diagnostic callback if callback is not None: - # Diagnostic callback - if adj_X is None: - callback(J_i, n, i, eq, - None) - elif len(adj_X) == 1: - callback(J_i, n, i, eq, - var_copy(adj_X[0])) - else: - callback(J_i, n, i, eq, - tuple(map(var_copy, adj_X))) + callback(J_i, n, i, eq, + None if adj_X is None else eq._unpack(adj_X)) if n == -1: assert i == 0 @@ -1372,7 +1338,7 @@ def action_pass(cp_action): del action garbage_cleanup(self._comm) - return tuple(dJ) + return Js_packed.unpack(tuple(M_packed.unpack(dJ_) for dJ_ in dJ)) set_manager(EquationManager()) diff --git a/tlm_adjoint/verification.py b/tlm_adjoint/verification.py index cf368979..d4389103 100644 --- a/tlm_adjoint/verification.py +++ b/tlm_adjoint/verification.py @@ -84,7 +84,7 @@ """ from .interface import ( - garbage_cleanup, is_var, space_comm, var_assign, var_axpy, var_copy, + garbage_cleanup, packed, space_comm, var_assign, var_axpy, var_copy, var_dtype, var_is_cached, var_is_static, var_local_size, var_name, var_new, var_set_values, vars_inner, vars_linf_norm, var_scalar_value) @@ -169,15 +169,13 @@ def taylor_test(forward, M, J_val, *, dJ=None, ddJ=None, seed=1.0e-2, dM=None, close to 3 if `ddJ` is supplied. """ - if is_var(M): - if dJ is not None: - dJ = (dJ,) - if dM is not None: - dM = (dM,) - if M0 is not None: - M0 = (M0,) - return taylor_test(forward, (M,), J_val, dJ=dJ, ddJ=ddJ, seed=seed, - dM=dM, M0=M0, size=size) + M = packed(M) + if dJ is not None: + dJ = packed(dJ) + if dM is not None: + dM = packed(dM) + if M0 is not None: + M0 = packed(M0) logger = logging.getLogger("tlm_adjoint.verification") forward = wrapped_forward(forward) @@ -277,11 +275,9 @@ def taylor_test_tlm(forward, M, tlm_order, *, seed=1.0e-2, dMs=None, size=5, verification this should be close to 2. """ - if is_var(M): - if dMs is not None: - dMs = tuple((dM,) for dM in dMs) - return taylor_test_tlm(forward, (M,), tlm_order, seed=seed, dMs=dMs, - size=size, manager=manager) + M = packed(M) + if dMs is not None: + dMs = tuple(map(packed, dMs)) logger = logging.getLogger("tlm_adjoint.verification") forward = wrapped_forward(forward) @@ -389,12 +385,9 @@ def taylor_test_tlm_adjoint(forward, M, adjoint_order, *, seed=1.0e-2, verification this should be close to 2. """ - if is_var(M): - if dMs is not None: - dMs = tuple((dM,) for dM in dMs) - return taylor_test_tlm_adjoint( - forward, (M,), adjoint_order, seed=seed, dMs=dMs, size=size, - manager=manager) + M = packed(M) + if dMs is not None: + dMs = tuple(map(packed, dMs)) forward = wrapped_forward(forward) if manager is None: From 22da9bbf974ea243141e14c4d760f2e25080e76e Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Tue, 28 May 2024 10:28:44 +0100 Subject: [PATCH 2/4] Tidying --- tlm_adjoint/block_system.py | 2 +- tlm_adjoint/cached_hessian.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tlm_adjoint/block_system.py b/tlm_adjoint/block_system.py index 3bd31987..a6164cb1 100644 --- a/tlm_adjoint/block_system.py +++ b/tlm_adjoint/block_system.py @@ -263,7 +263,7 @@ def __init__(self, spaces): spaces = tuple_sub(flattened_spaces, spaces) super().__init__(tuple(space.space for space in flattened_spaces)) - self._spaces = tuple_sub(spaces, spaces) + self._spaces = spaces self._flattened_spaces = flattened_spaces def __len__(self): diff --git a/tlm_adjoint/cached_hessian.py b/tlm_adjoint/cached_hessian.py index 907531fa..3e9e4f97 100644 --- a/tlm_adjoint/cached_hessian.py +++ b/tlm_adjoint/cached_hessian.py @@ -222,7 +222,7 @@ def __init__(self, X, R_inv_action, B_inv_action=None, *, cache_adjoint=False) GaussNewton.__init__( self, R_inv_action, B_inv_action=B_inv_action) - self._X = X + self._X = tuple(X) def _setup_manager(self, M, dM, M0=None, *, annotate_tlm=False, solve_tlm=True): From e69455ad353678479ce9b901823d2aad6d958772 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Tue, 28 May 2024 10:45:30 +0100 Subject: [PATCH 3/4] Fix --- tlm_adjoint/block_system.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tlm_adjoint/block_system.py b/tlm_adjoint/block_system.py index a6164cb1..66a67b9d 100644 --- a/tlm_adjoint/block_system.py +++ b/tlm_adjoint/block_system.py @@ -70,8 +70,8 @@ """ from .interface import ( - comm_dup_cached, packed, space_comm, space_default_space_type, space_eq, - space_new, var_assign, var_locked, var_zero) + Packed, comm_dup_cached, packed, space_comm, space_default_space_type, + space_eq, space_new, var_assign, var_locked, var_zero) from .manager import manager_disabled from .petsc import ( PETScOptions, PETScVec, PETScVecInterface, attach_destroy_finalizer) @@ -113,7 +113,7 @@ def expand(e): q = deque(map(expand, iterable)) while len(q) > 0: - e = packed(q.popleft()) + e = Packed(q.popleft()) if e.is_packed: e, = e yield e @@ -125,7 +125,7 @@ def tuple_sub(iterable, sequence): iterator = iter_sub(iterable) def tuple_sub(iterator, value): - value = packed(value) + value = Packed(value) if value.is_packed: return next(iterator) else: @@ -859,18 +859,20 @@ def solve(self, u, b, *, the solution. """ - u = packed(u) - b = packed(b) + u_packed = Packed(u) + b_packed = Packed(b) + u = tuple(u_packed) + b = tuple(b_packed) pc_fn = self._pc_fn - if u.is_packed: + if u_packed.is_packed: pc_fn_u = pc_fn def pc_fn(u, b): u, = tuple(iter_sub(u)) pc_fn_u(u, b) - if b.is_packed: + if b_packed.is_packed: pc_fn_b = pc_fn def pc_fn(u, b): From a49bf1e042973f84b5c72a1eff56686116392aed Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Tue, 28 May 2024 12:07:16 +0100 Subject: [PATCH 4/4] Fix --- tlm_adjoint/hessian_system.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tlm_adjoint/hessian_system.py b/tlm_adjoint/hessian_system.py index 09ceedf7..70b07e42 100644 --- a/tlm_adjoint/hessian_system.py +++ b/tlm_adjoint/hessian_system.py @@ -1,6 +1,6 @@ from .interface import ( - packed, var_axpy_conjugate, var_copy_conjugate, var_increment_state_lock, - var_space) + Packed, packed, var_axpy_conjugate, var_copy_conjugate, + var_increment_state_lock, var_space) from .block_system import ( BlockNullspace, LinearSolver, Matrix, NoneNullspace, TypedSpace) @@ -89,5 +89,5 @@ def solve(self, u, b, **kwargs): :meth:`tlm_adjoint.block_system.LinearSolver.solve` method. """ - b_conj = tuple(map(var_copy_conjugate, packed(b))) - super().solve(u, b_conj, **kwargs) + b_conj = Packed(b).mapped(var_copy_conjugate) + super().solve(u, b_conj.unpack(b_conj), **kwargs)