From 6bf74a16815ac890d28d8463697146f48583a58b Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Tue, 31 Oct 2023 12:58:22 +0000 Subject: [PATCH 1/3] Use variable state locking in the CachedHessian and CachedGaussNewton classes --- tlm_adjoint/cached_hessian.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/tlm_adjoint/cached_hessian.py b/tlm_adjoint/cached_hessian.py index bdd29dc2..3f7c85d1 100644 --- a/tlm_adjoint/cached_hessian.py +++ b/tlm_adjoint/cached_hessian.py @@ -2,7 +2,8 @@ # -*- coding: utf-8 -*- from .interface import ( - StateLockDictionary, var_id, var_new, var_scalar_value, var_state) + StateLockDictionary, var_check_state_lock, var_id, + var_increment_state_lock, var_new, var_scalar_value) from .caches import clear_caches from .hessian import GaussNewton, Hessian @@ -146,10 +147,11 @@ class CachedHessian(Hessian, HessianOptimization): """ def __init__(self, J, *, manager=None, cache_adjoint=True): + var_increment_state_lock(J, self) + HessianOptimization.__init__(self, manager=manager, cache_adjoint=cache_adjoint) Hessian.__init__(self) - self._J_state = var_state(J) self._J = J @restore_manager @@ -160,8 +162,7 @@ def compute_gradient(self, M, M0=None): M0=None if M0 is None else (M0,)) return J_val, dJ - if var_state(self._J) != self._J_state: - raise RuntimeError("State has changed") + var_check_state_lock(self._J) dM = tuple(map(var_new, M)) manager, M, dM = self._setup_manager(M, dM, M0=M0, solve_tlm=False) @@ -185,8 +186,7 @@ def action(self, M, dM, M0=None): M0=None if M0 is None else (M0,)) return J_val, dJ_val, ddJ - if var_state(self._J) != self._J_state: - raise RuntimeError("State has changed") + var_check_state_lock(self._J) manager, M, dM = self._setup_manager(M, dM, M0=M0, solve_tlm=True) set_manager(manager) @@ -220,12 +220,14 @@ def __init__(self, X, R_inv_action, B_inv_action=None, *, if not isinstance(X, Sequence): X = (X,) + for x in X: + var_increment_state_lock(x, self) + HessianOptimization.__init__(self, manager=manager, cache_adjoint=False) GaussNewton.__init__( self, R_inv_action, B_inv_action=B_inv_action) self._X = tuple(X) - self._X_state = tuple(map(var_state, X)) def _setup_manager(self, M, dM, M0=None, *, annotate_tlm=False, solve_tlm=True): @@ -235,7 +237,7 @@ def _setup_manager(self, M, dM, M0=None, *, return manager, M, dM, self._X def action(self, M, dM, M0=None): - if tuple(map(var_state, self._X)) != self._X_state: - raise RuntimeError("State has changed") + for x in self._X: + var_check_state_lock(x) return GaussNewton.action(self, M, dM, M0=M0) From 7ce76ff38d5a4d260f9a0808a3f56cdd21a32cea Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Tue, 31 Oct 2023 12:58:41 +0000 Subject: [PATCH 2/3] FEniCS/Firedrake backends: Use variable state locking with DirichletBC objects --- tests/fenics/test_interface.py | 73 +++++++++++++++++++ tests/firedrake/test_interface.py | 81 ++++++++++++++++++++++ tlm_adjoint/fenics/backend_overrides.py | 24 ++++++- tlm_adjoint/fenics/equations.py | 17 ++++- tlm_adjoint/fenics/functions.py | 39 +++-------- tlm_adjoint/firedrake/backend_overrides.py | 22 +++++- tlm_adjoint/firedrake/equations.py | 17 ++++- tlm_adjoint/firedrake/functions.py | 47 +++---------- tlm_adjoint/interface.py | 9 ++- tlm_adjoint/override.py | 22 ++++-- 10 files changed, 268 insertions(+), 83 deletions(-) diff --git a/tests/fenics/test_interface.py b/tests/fenics/test_interface.py index ed3321e9..9ef360b5 100644 --- a/tests/fenics/test_interface.py +++ b/tests/fenics/test_interface.py @@ -2,10 +2,13 @@ # -*- coding: utf-8 -*- from fenics import * +from tlm_adjoint.alias import WeakAlias +from tlm_adjoint.interface import VariableStateChangeError from tlm_adjoint.fenics import * from .test_base import * +import gc import pytest pytestmark = pytest.mark.skipif( @@ -103,3 +106,73 @@ def test_scalar(x, ref): test_scalar(Constant(0.0), True) test_scalar(Constant((0.0, 0.0)), False) test_scalar(Function(space), False) + + +@pytest.mark.fenics +@seed_test +def test_DirichletBC_locking(setup_test, test_leaks): + def new_bc(space_cls): + mesh = UnitSquareMesh(10, 10) + space = space_cls(mesh, "Lagrange", 1) + bc_value = Function(space, name="bc_value") + bc = DirichletBC(space, bc_value, "on_boundary", static=True) + return space, bc + + def new_hbc(space_cls): + mesh = UnitSquareMesh(10, 10) + space = space_cls(mesh, "Lagrange", 1) + bc = HomogeneousDirichletBC(space, "on_boundary") + return space, bc + + for space_cls in (FunctionSpace, + VectorFunctionSpace, + TensorFunctionSpace): + _, bc = new_bc(space_cls) + try: + bc.homogenize() + assert False + except VariableStateChangeError: + pass + del bc + + space, bc = new_bc(space_cls) + try: + bc.set_value(Function(space)) + assert False + except VariableStateChangeError: + pass + del space, bc + + _, bc = new_hbc(space_cls) + bc.homogenize() + del bc + + def new_eq(): + mesh = UnitSquareMesh(10, 10) + space = FunctionSpace(mesh, "Lagrange", 1) + bc_value = Function(space, name="bc_value") + bc = DirichletBC(space, bc_value, "on_boundary") + test, trial = TestFunction(space), TrialFunction(space) + u = Function(space, name="u") + eq = EquationSolver(inner(trial, test) * dx + == inner(Constant(1.0), test) * dx, u, bc) + return space, bc, eq + + _, bc, eq = new_eq() + try: + bc.homogenize() + assert False + except VariableStateChangeError: + pass + del bc, eq + + _, bc, eq = new_eq() + eq = WeakAlias(eq) + gc.collect() + garbage_cleanup() + try: + bc.homogenize() + assert False + except VariableStateChangeError: + pass + del bc, eq diff --git a/tests/firedrake/test_interface.py b/tests/firedrake/test_interface.py index ee57231e..5f4ba867 100644 --- a/tests/firedrake/test_interface.py +++ b/tests/firedrake/test_interface.py @@ -2,10 +2,13 @@ # -*- coding: utf-8 -*- from firedrake import * +from tlm_adjoint.alias import WeakAlias +from tlm_adjoint.interface import VariableStateChangeError from tlm_adjoint.firedrake import * from .test_base import * +import gc import pytest import ufl @@ -153,3 +156,81 @@ def test_scalar(x, ref): test_scalar(Constant((0.0, 0.0)), False) test_scalar(Function(space), False) test_scalar(Cofunction(space.dual()), False) + + +@pytest.mark.firedrake +@seed_test +def test_DirichletBC_locking(setup_test, test_leaks): + def new_bc(space_cls): + mesh = UnitSquareMesh(10, 10) + space = space_cls(mesh, "Lagrange", 1) + bc_value = Function(space, name="bc_value") + bc = DirichletBC(space, bc_value, "on_boundary", static=True) + return space, bc + + def new_hbc(space_cls): + mesh = UnitSquareMesh(10, 10) + space = space_cls(mesh, "Lagrange", 1) + bc = HomogeneousDirichletBC(space, "on_boundary") + return space, bc + + for space_cls in (FunctionSpace, + VectorFunctionSpace, + TensorFunctionSpace): + _, bc = new_bc(space_cls) + try: + bc.homogenize() + assert False + except VariableStateChangeError: + pass + del bc + + space, bc = new_bc(space_cls) + try: + bc.set_value(Function(space)) + assert False + except VariableStateChangeError: + pass + del space, bc + + space, bc = new_bc(space_cls) + try: + bc.function_arg = Function(space) + assert False + except VariableStateChangeError: + pass + del space, bc + + _, bc = new_hbc(space_cls) + bc.homogenize() + del bc + + def new_eq(): + mesh = UnitSquareMesh(10, 10) + space = FunctionSpace(mesh, "Lagrange", 1) + bc_value = Function(space, name="bc_value") + bc = DirichletBC(space, bc_value, "on_boundary") + test, trial = TestFunction(space), TrialFunction(space) + u = Function(space, name="u") + eq = EquationSolver(inner(trial, test) * dx + == inner(Constant(1.0), test) * dx, u, bc) + return space, bc, eq + + _, bc, eq = new_eq() + try: + bc.homogenize() + assert False + except VariableStateChangeError: + pass + del bc, eq + + _, bc, eq = new_eq() + eq = WeakAlias(eq) + gc.collect() + garbage_cleanup() + try: + bc.homogenize() + assert False + except VariableStateChangeError: + pass + del bc, eq diff --git a/tlm_adjoint/fenics/backend_overrides.py b/tlm_adjoint/fenics/backend_overrides.py index b20d9fbc..93b964dd 100644 --- a/tlm_adjoint/fenics/backend_overrides.py +++ b/tlm_adjoint/fenics/backend_overrides.py @@ -8,8 +8,8 @@ backend_Vector, backend_project, backend_solve, cpp_Assembler, cpp_SystemAssembler, parameters) from ..interface import ( - space_id, space_new, var_assign, var_comm, var_new, var_space, - var_update_state) + VariableStateChangeError, is_var, space_id, space_new, var_assign, + var_comm, var_new, var_space, var_state_is_locked, var_update_state) from .backend_code_generator_interface import ( copy_parameters_dict, update_parameters_dict) @@ -316,8 +316,26 @@ def project(orig, orig_args, v, V=None, bcs=None, mesh=None, function=None, return return_value +@override_method(backend_DirichletBC, "homogenize") +def DirichletBC_homogenize(self, orig, orig_args, *args, **kwargs): + bc_value = getattr(self, "_tlm_adjoint__bc_value", None) + if is_var(bc_value) and var_state_is_locked(bc_value): + raise VariableStateChangeError("Cannot change DirichletBC if the " + "value state is locked") + return orig_args() + + +@override_method(backend_DirichletBC, "set_value") +def DirichletBC_set_value(self, orig, orig_args, *args, **kwargs): + bc_value = getattr(self, "_tlm_adjoint__bc_value", None) + if is_var(bc_value) and var_state_is_locked(bc_value): + raise VariableStateChangeError("Cannot change DirichletBC if the " + "value state is locked") + return orig_args() + + @override_method(backend_DirichletBC, "apply") -def _DirichletBC_apply(self, orig, orig_args, *args): +def DirichletBC_apply(self, orig, orig_args, *args): A = None b = None x = None diff --git a/tlm_adjoint/fenics/equations.py b/tlm_adjoint/fenics/equations.py index 0f61082d..163afb2b 100644 --- a/tlm_adjoint/fenics/equations.py +++ b/tlm_adjoint/fenics/equations.py @@ -10,9 +10,9 @@ TestFunction, TrialFunction, adjoint, backend_Constant, backend_DirichletBC, backend_Function, parameters) from ..interface import ( - check_space_type, is_var, var_assign, var_id, var_is_scalar, var_new, - var_new_conjugate_dual, var_replacement, var_scalar_value, var_space, - var_zero) + check_space_type, is_var, var_assign, var_id, var_increment_state_lock, + var_is_scalar, var_new, var_new_conjugate_dual, var_replacement, + var_scalar_value, var_space, var_zero) from .backend_code_generator_interface import ( assemble, assemble_linear_solver, copy_parameters_dict, form_compiler_quadrature_parameters, homogenize, interpolate_expression, @@ -29,6 +29,7 @@ ReplacementConstant, bcs_is_cached, bcs_is_homogeneous, bcs_is_static, derivative, eliminate_zeros, extract_coefficients) +import itertools import numpy as np try: import ufl_legacy as ufl @@ -391,6 +392,15 @@ def __init__(self, eq, x, bcs=None, *, hbcs = tuple(map(homogenized_bc, bcs)) + class DirichletBCLock: + pass + + bc_lock = DirichletBCLock() + for bc in itertools.chain(bcs, hbcs): + bc_value = getattr(bc, "_tlm_adjoint__bc_value", None) + if is_var(bc_value): + var_increment_state_lock(bc_value, bc_lock) + if cache_jacobian is None: cache_jacobian = is_cached(J) and bcs_is_cached(bcs) if cache_adjoint_jacobian is None: @@ -429,6 +439,7 @@ def __init__(self, eq, x, bcs=None, *, self._rhs = rhs self._bcs = bcs self._hbcs = hbcs + self._bc_lock = bc_lock self._J = J self._nl_solve_J = nl_solve_J self._form_compiler_parameters = form_compiler_parameters diff --git a/tlm_adjoint/fenics/functions.py b/tlm_adjoint/fenics/functions.py index 05396a5c..f5f0512a 100644 --- a/tlm_adjoint/fenics/functions.py +++ b/tlm_adjoint/fenics/functions.py @@ -9,11 +9,11 @@ TestFunction, TrialFunction, backend_Constant, backend_DirichletBC, backend_ScalarType) from ..interface import ( - DEFAULT_COMM, SpaceInterface, add_interface, comm_parent, is_var, - space_comm, var_caches, var_comm, var_dtype, var_derivative_space, var_id, - var_is_cached, var_is_replacement, var_is_static, var_linf_norm, - var_lock_state, var_name, var_replacement, var_scalar_value, var_space, - var_space_type) + DEFAULT_COMM, SpaceInterface, VariableStateChangeError, add_interface, + comm_parent, is_var, space_comm, var_caches, var_comm, var_dtype, + var_derivative_space, var_id, var_increment_state_lock, var_is_cached, + var_is_replacement, var_is_static, var_linf_norm, var_lock_state, var_name, + var_replacement, var_scalar_value, var_space, var_space_type) from ..interface import VariableInterface as _VariableInterface from ..caches import Caches @@ -320,7 +320,8 @@ class Zero: """ def _tlm_adjoint__var_interface_update_state(self): - raise RuntimeError("Cannot call _update_state interface of Zero") + raise VariableStateChangeError("Cannot call _update_state interface " + "of Zero") class ZeroConstant(Constant, Zero): @@ -449,32 +450,14 @@ def __init__(self, V, g, sub_domain, *args, else: static = True + if static and is_var(g): + var_increment_state_lock(g, self) + + self._tlm_adjoint__bc_value = g self._tlm_adjoint__static = static self._tlm_adjoint__cache = static self._tlm_adjoint__homogeneous = _homogeneous - def homogenize(self): - """Homogenize the :class:`.DirichletBC`, setting its value to zero. - """ - - if self._tlm_adjoint__static: - raise RuntimeError("Cannot call homogenize method for static " - "DirichletBC") - if not self._tlm_adjoint__homogeneous: - super().homogenize() - self._tlm_adjoint__homogeneous = True - - def set_value(self, *args, **kwargs): - """Set the :class:`.DirichletBC` value. - - Arguments are passed to the DOLFIN `DirichletBC.set_value` method. - """ - - if self._tlm_adjoint__static: - raise RuntimeError("Cannot call set_value method for static " - "DirichletBC") - super().set_value(*args, **kwargs) - class HomogeneousDirichletBC(DirichletBC): """A :class:`.DirichletBC` whose value is zero. diff --git a/tlm_adjoint/firedrake/backend_overrides.py b/tlm_adjoint/firedrake/backend_overrides.py index 5bb90630..ea74c950 100644 --- a/tlm_adjoint/firedrake/backend_overrides.py +++ b/tlm_adjoint/firedrake/backend_overrides.py @@ -7,14 +7,15 @@ backend_Function, backend_Vector, backend_assemble, backend_interpolate, backend_project, backend_solve, parameters) from ..interface import ( - is_var, space_id, var_comm, var_new, var_space, var_update_state) + VariableStateChangeError, is_var, space_id, var_comm, var_new, var_space, + var_state_is_locked, var_update_state) from .backend_code_generator_interface import ( copy_parameters_dict, update_parameters_dict) from ..equation import ZeroAssignment from ..equations import Assignment from ..override import ( - add_manager_controls, manager_method, override_method) + add_manager_controls, manager_method, override_method, override_property) from .equations import ( Assembly, EquationSolver, ExprInterpolation, Projection, expr_new_x, @@ -109,6 +110,23 @@ def var_update_state_post_call(self, return_value, *args, **kwargs): return return_value +def DirichletBC_function_arg_fset(self, orig, orig_args, g): + if getattr(self, "_tlm_adjoint__function_arg_set", False) \ + and is_var(self.function_arg) \ + and var_state_is_locked(self.function_arg): + raise VariableStateChangeError("Cannot change DirichletBC if the " + "value state is locked") + return_value = orig_args() + self._tlm_adjoint__function_arg_set = True + return return_value + + +@override_property(backend_DirichletBC, "function_arg", + fset=DirichletBC_function_arg_fset) +def DirichletBC_function_arg(self, orig): + return orig() + + @manager_method(backend_Constant, "assign", post_call=var_update_state_post_call) def Constant_assign(self, orig, orig_args, value, *, annotate, tlm): diff --git a/tlm_adjoint/firedrake/equations.py b/tlm_adjoint/firedrake/equations.py index 3e7ae8a9..0ceb3c9c 100644 --- a/tlm_adjoint/firedrake/equations.py +++ b/tlm_adjoint/firedrake/equations.py @@ -10,9 +10,9 @@ TestFunction, TrialFunction, adjoint, backend_Constant, backend_DirichletBC, backend_Function, parameters) from ..interface import ( - check_space_type, is_var, var_assign, var_id, var_is_scalar, var_new, - var_new_conjugate_dual, var_replacement, var_scalar_value, var_space, - var_zero) + check_space_type, is_var, var_assign, var_id, var_increment_state_lock, + var_is_scalar, var_new, var_new_conjugate_dual, var_replacement, + var_scalar_value, var_space, var_zero) from .backend_code_generator_interface import ( assemble, assemble_linear_solver, copy_parameters_dict, form_compiler_quadrature_parameters, homogenize, interpolate_expression, @@ -29,6 +29,7 @@ ReplacementConstant, bcs_is_cached, bcs_is_homogeneous, bcs_is_static, derivative, eliminate_zeros, extract_coefficients) +import itertools import numpy as np import ufl @@ -389,6 +390,15 @@ def __init__(self, eq, x, bcs=None, *, hbcs = tuple(map(homogenized_bc, bcs)) + class DirichletBCLock: + pass + + bc_lock = DirichletBCLock() + for bc in itertools.chain(bcs, hbcs): + bc_value = bc.function_arg + if is_var(bc_value): + var_increment_state_lock(bc_value, bc_lock) + if cache_jacobian is None: cache_jacobian = is_cached(J) and bcs_is_cached(bcs) if cache_adjoint_jacobian is None: @@ -427,6 +437,7 @@ def __init__(self, eq, x, bcs=None, *, self._rhs = rhs self._bcs = bcs self._hbcs = hbcs + self._bc_lock = bc_lock self._J = J self._nl_solve_J = nl_solve_J self._form_compiler_parameters = form_compiler_parameters diff --git a/tlm_adjoint/firedrake/functions.py b/tlm_adjoint/firedrake/functions.py index 6f7fb759..bc0796df 100644 --- a/tlm_adjoint/firedrake/functions.py +++ b/tlm_adjoint/firedrake/functions.py @@ -9,11 +9,11 @@ TestFunction, TrialFunction, backend_Constant, backend_DirichletBC, backend_ScalarType) from ..interface import ( - DEFAULT_COMM, SpaceInterface, add_interface, comm_parent, is_var, - space_comm, var_caches, var_comm, var_dtype, var_derivative_space, var_id, - var_is_cached, var_is_replacement, var_is_static, var_linf_norm, - var_lock_state, var_name, var_replacement, var_scalar_value, var_space, - var_space_type) + DEFAULT_COMM, SpaceInterface, VariableStateChangeError, add_interface, + comm_parent, is_var, space_comm, var_caches, var_comm, var_dtype, + var_derivative_space, var_id, var_increment_state_lock, var_is_cached, + var_is_replacement, var_is_static, var_linf_norm, var_lock_state, var_name, + var_replacement, var_scalar_value, var_space, var_space_type) from ..interface import VariableInterface as _VariableInterface from ..caches import Caches @@ -337,17 +337,9 @@ class Zero: variables for which UFL zero elimination should not be applied. """ - def _tlm_adjoint__var_interface_assign(self, y): - raise RuntimeError("Cannot call _assign interface of Zero") - - def _tlm_adjoint__var_interface_axpy(self, alpha, x, /): - raise RuntimeError("Cannot call _axpy interface of Zero") - - def _tlm_adjoint__var_interface_set_values(self, values): - raise RuntimeError("Cannot call _set_values interface of Zero") - def _tlm_adjoint__var_interface_update_state(self): - raise RuntimeError("Cannot call _update_state interface of Zero") + raise VariableStateChangeError("Cannot call _update_state interface " + "of Zero") class ZeroConstant(Constant, Zero): @@ -560,32 +552,13 @@ def __init__(self, V, g, sub_domain, *args, else: static = True + if static and is_var(self.function_arg): + var_increment_state_lock(self.function_arg, self) + self._tlm_adjoint__static = static self._tlm_adjoint__cache = static self._tlm_adjoint__homogeneous = _homogeneous - def homogenize(self): - """Homogenize the :class:`.DirichletBC`, setting its value to zero. - """ - - if self._tlm_adjoint__static: - raise RuntimeError("Cannot call homogenize method for static " - "DirichletBC") - if not self._tlm_adjoint__homogeneous: - super().homogenize() - self._tlm_adjoint__homogeneous = True - - def set_value(self, *args, **kwargs): - """Set the :class:`.DirichletBC` value. - - Arguments are passed to :meth:`firedrake.bcs.DirichletBC.set_value`. - """ - - if self._tlm_adjoint__static: - raise RuntimeError("Cannot call set_value method for static " - "DirichletBC") - super().set_value(*args, **kwargs) - class HomogeneousDirichletBC(DirichletBC): """A :class:`.DirichletBC` whose value is zero. diff --git a/tlm_adjoint/interface.py b/tlm_adjoint/interface.py index 2ac92e93..b998bfb1 100644 --- a/tlm_adjoint/interface.py +++ b/tlm_adjoint/interface.py @@ -974,6 +974,10 @@ def var_decrement_state_lock(x, obj): del obj._tlm_adjoint__state_locks[x_id] +class VariableStateChangeError(RuntimeError): + pass + + def var_lock_state(x): """Lock the state of a variable. @@ -997,7 +1001,7 @@ def var_state_is_locked(x): def var_check_state_lock(x): if var_state_is_locked(x) \ and x._tlm_adjoint__state_lock_state != var_state(x): - raise RuntimeError("State change while locked") + raise VariableStateChangeError("State change while locked") class StateLockDictionary(MutableMapping): @@ -1060,7 +1064,8 @@ def var_update_state(*X): raise ValueError("x cannot be a replacement") var_check_state_lock(x) if var_state_is_locked(x): - raise RuntimeError("Cannot update state for locked variable") + raise VariableStateChangeError("Cannot update state for locked " + "variable") x._tlm_adjoint__var_interface_update_state() var_update_caches(*X) diff --git a/tlm_adjoint/override.py b/tlm_adjoint/override.py index ad9d1f8d..109042e0 100644 --- a/tlm_adjoint/override.py +++ b/tlm_adjoint/override.py @@ -29,17 +29,29 @@ def wrapped_override(self, *args, **kwargs): def override_property(cls, name, *, - cached=False): + fset=None, cached=False): orig = getattr(cls, name) def wrapper(override): - property_decorator = functools.cached_property if cached else property + if fset is not None: + @functools.wraps(fset) + def wrapped_fset(self, *args, **kwargs): + return fset(self, orig.fset, + lambda: orig.fset(self, *args, **kwargs), + *args, **kwargs) + + if cached: + if fset is not None: + raise TypeError("Cannot use fset with a cached_property") + property_decorator = functools.cached_property + else: + def property_decorator(arg): + return property(arg, fset=wrapped_fset) @property_decorator @functools.wraps(orig) - def wrapped_override(self, *args, **kwargs): - return override(self, lambda: orig.__get__(self, type(self)), - *args, **kwargs) + def wrapped_override(self): + return override(self, lambda: orig.__get__(self, type(self))) setattr(cls, name, wrapped_override) if cached: From c5bb879603ff54f1291e4814ebbe13f91bc3b28e Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Tue, 31 Oct 2023 13:31:36 +0000 Subject: [PATCH 3/3] Firedrake backend: Parallel R0 fix --- tlm_adjoint/firedrake/backend_interface.py | 7 +++++-- tlm_adjoint/interface.py | 9 +++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/tlm_adjoint/firedrake/backend_interface.py b/tlm_adjoint/firedrake/backend_interface.py index ac37a896..d781cae0 100644 --- a/tlm_adjoint/firedrake/backend_interface.py +++ b/tlm_adjoint/firedrake/backend_interface.py @@ -220,10 +220,13 @@ def _local_indices(self): return slice(*local_range) def _get_values(self): - return self.dat.data_ro.flatten().copy() + with self.dat.vec_ro as x_v: + x_a = x_v.getArray(True) + return x_a.copy() def _set_values(self, values): - self.dat.data[:] = values.reshape(self.dat.data_ro.shape)[:] + with self.dat.vec_wo as x_v: + x_v.setArray(values) def _is_replacement(self): return False diff --git a/tlm_adjoint/interface.py b/tlm_adjoint/interface.py index b998bfb1..58573e75 100644 --- a/tlm_adjoint/interface.py +++ b/tlm_adjoint/interface.py @@ -1252,7 +1252,12 @@ def var_get_values(x): freedom. """ - return x._tlm_adjoint__var_interface_get_values() + values = x._tlm_adjoint__var_interface_get_values() + if not np.can_cast(values, var_dtype(x)): + raise ValueError("Invalid dtype") + if values.shape != (var_local_size(x),): + raise ValueError("Invalid shape") + return values def var_set_values(x, values): @@ -1266,7 +1271,7 @@ def var_set_values(x, values): if not np.can_cast(values, var_dtype(x)): raise ValueError("Invalid dtype") - if not values.shape == (var_local_size(x),): + if values.shape != (var_local_size(x),): raise ValueError("Invalid shape") x._tlm_adjoint__var_interface_set_values(values) var_update_state(x)