Skip to content

Commit

Permalink
Merge pull request #423 from tlm-adjoint/jrmaddison/state_locking
Browse files Browse the repository at this point in the history
Expand use of variable state locks
  • Loading branch information
jrmaddison authored Oct 31, 2023
2 parents c1ee706 + c5bb879 commit 3fd05c9
Show file tree
Hide file tree
Showing 12 changed files with 291 additions and 96 deletions.
73 changes: 73 additions & 0 deletions tests/fenics/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
# -*- coding: utf-8 -*-

from fenics import *
from tlm_adjoint.alias import WeakAlias
from tlm_adjoint.interface import VariableStateChangeError
from tlm_adjoint.fenics import *

from .test_base import *

import gc
import pytest

pytestmark = pytest.mark.skipif(
Expand Down Expand Up @@ -103,3 +106,73 @@ def test_scalar(x, ref):
test_scalar(Constant(0.0), True)
test_scalar(Constant((0.0, 0.0)), False)
test_scalar(Function(space), False)


@pytest.mark.fenics
@seed_test
def test_DirichletBC_locking(setup_test, test_leaks):
def new_bc(space_cls):
mesh = UnitSquareMesh(10, 10)
space = space_cls(mesh, "Lagrange", 1)
bc_value = Function(space, name="bc_value")
bc = DirichletBC(space, bc_value, "on_boundary", static=True)
return space, bc

def new_hbc(space_cls):
mesh = UnitSquareMesh(10, 10)
space = space_cls(mesh, "Lagrange", 1)
bc = HomogeneousDirichletBC(space, "on_boundary")
return space, bc

for space_cls in (FunctionSpace,
VectorFunctionSpace,
TensorFunctionSpace):
_, bc = new_bc(space_cls)
try:
bc.homogenize()
assert False
except VariableStateChangeError:
pass
del bc

space, bc = new_bc(space_cls)
try:
bc.set_value(Function(space))
assert False
except VariableStateChangeError:
pass
del space, bc

_, bc = new_hbc(space_cls)
bc.homogenize()
del bc

def new_eq():
mesh = UnitSquareMesh(10, 10)
space = FunctionSpace(mesh, "Lagrange", 1)
bc_value = Function(space, name="bc_value")
bc = DirichletBC(space, bc_value, "on_boundary")
test, trial = TestFunction(space), TrialFunction(space)
u = Function(space, name="u")
eq = EquationSolver(inner(trial, test) * dx
== inner(Constant(1.0), test) * dx, u, bc)
return space, bc, eq

_, bc, eq = new_eq()
try:
bc.homogenize()
assert False
except VariableStateChangeError:
pass
del bc, eq

_, bc, eq = new_eq()
eq = WeakAlias(eq)
gc.collect()
garbage_cleanup()
try:
bc.homogenize()
assert False
except VariableStateChangeError:
pass
del bc, eq
81 changes: 81 additions & 0 deletions tests/firedrake/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
# -*- coding: utf-8 -*-

from firedrake import *
from tlm_adjoint.alias import WeakAlias
from tlm_adjoint.interface import VariableStateChangeError
from tlm_adjoint.firedrake import *

from .test_base import *

import gc
import pytest
import ufl

Expand Down Expand Up @@ -153,3 +156,81 @@ def test_scalar(x, ref):
test_scalar(Constant((0.0, 0.0)), False)
test_scalar(Function(space), False)
test_scalar(Cofunction(space.dual()), False)


@pytest.mark.firedrake
@seed_test
def test_DirichletBC_locking(setup_test, test_leaks):
def new_bc(space_cls):
mesh = UnitSquareMesh(10, 10)
space = space_cls(mesh, "Lagrange", 1)
bc_value = Function(space, name="bc_value")
bc = DirichletBC(space, bc_value, "on_boundary", static=True)
return space, bc

def new_hbc(space_cls):
mesh = UnitSquareMesh(10, 10)
space = space_cls(mesh, "Lagrange", 1)
bc = HomogeneousDirichletBC(space, "on_boundary")
return space, bc

for space_cls in (FunctionSpace,
VectorFunctionSpace,
TensorFunctionSpace):
_, bc = new_bc(space_cls)
try:
bc.homogenize()
assert False
except VariableStateChangeError:
pass
del bc

space, bc = new_bc(space_cls)
try:
bc.set_value(Function(space))
assert False
except VariableStateChangeError:
pass
del space, bc

space, bc = new_bc(space_cls)
try:
bc.function_arg = Function(space)
assert False
except VariableStateChangeError:
pass
del space, bc

_, bc = new_hbc(space_cls)
bc.homogenize()
del bc

def new_eq():
mesh = UnitSquareMesh(10, 10)
space = FunctionSpace(mesh, "Lagrange", 1)
bc_value = Function(space, name="bc_value")
bc = DirichletBC(space, bc_value, "on_boundary")
test, trial = TestFunction(space), TrialFunction(space)
u = Function(space, name="u")
eq = EquationSolver(inner(trial, test) * dx
== inner(Constant(1.0), test) * dx, u, bc)
return space, bc, eq

_, bc, eq = new_eq()
try:
bc.homogenize()
assert False
except VariableStateChangeError:
pass
del bc, eq

_, bc, eq = new_eq()
eq = WeakAlias(eq)
gc.collect()
garbage_cleanup()
try:
bc.homogenize()
assert False
except VariableStateChangeError:
pass
del bc, eq
20 changes: 11 additions & 9 deletions tlm_adjoint/cached_hessian.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
# -*- coding: utf-8 -*-

from .interface import (
StateLockDictionary, var_id, var_new, var_scalar_value, var_state)
StateLockDictionary, var_check_state_lock, var_id,
var_increment_state_lock, var_new, var_scalar_value)

from .caches import clear_caches
from .hessian import GaussNewton, Hessian
Expand Down Expand Up @@ -146,10 +147,11 @@ class CachedHessian(Hessian, HessianOptimization):
"""

def __init__(self, J, *, manager=None, cache_adjoint=True):
var_increment_state_lock(J, self)

HessianOptimization.__init__(self, manager=manager,
cache_adjoint=cache_adjoint)
Hessian.__init__(self)
self._J_state = var_state(J)
self._J = J

@restore_manager
Expand All @@ -160,8 +162,7 @@ def compute_gradient(self, M, M0=None):
M0=None if M0 is None else (M0,))
return J_val, dJ

if var_state(self._J) != self._J_state:
raise RuntimeError("State has changed")
var_check_state_lock(self._J)

dM = tuple(map(var_new, M))
manager, M, dM = self._setup_manager(M, dM, M0=M0, solve_tlm=False)
Expand All @@ -185,8 +186,7 @@ def action(self, M, dM, M0=None):
M0=None if M0 is None else (M0,))
return J_val, dJ_val, ddJ

if var_state(self._J) != self._J_state:
raise RuntimeError("State has changed")
var_check_state_lock(self._J)

manager, M, dM = self._setup_manager(M, dM, M0=M0, solve_tlm=True)
set_manager(manager)
Expand Down Expand Up @@ -220,12 +220,14 @@ def __init__(self, X, R_inv_action, B_inv_action=None, *,
if not isinstance(X, Sequence):
X = (X,)

for x in X:
var_increment_state_lock(x, self)

HessianOptimization.__init__(self, manager=manager,
cache_adjoint=False)
GaussNewton.__init__(
self, R_inv_action, B_inv_action=B_inv_action)
self._X = tuple(X)
self._X_state = tuple(map(var_state, X))

def _setup_manager(self, M, dM, M0=None, *,
annotate_tlm=False, solve_tlm=True):
Expand All @@ -235,7 +237,7 @@ def _setup_manager(self, M, dM, M0=None, *,
return manager, M, dM, self._X

def action(self, M, dM, M0=None):
if tuple(map(var_state, self._X)) != self._X_state:
raise RuntimeError("State has changed")
for x in self._X:
var_check_state_lock(x)

return GaussNewton.action(self, M, dM, M0=M0)
24 changes: 21 additions & 3 deletions tlm_adjoint/fenics/backend_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
backend_Vector, backend_project, backend_solve, cpp_Assembler,
cpp_SystemAssembler, parameters)
from ..interface import (
space_id, space_new, var_assign, var_comm, var_new, var_space,
var_update_state)
VariableStateChangeError, is_var, space_id, space_new, var_assign,
var_comm, var_new, var_space, var_state_is_locked, var_update_state)
from .backend_code_generator_interface import (
copy_parameters_dict, update_parameters_dict)

Expand Down Expand Up @@ -316,8 +316,26 @@ def project(orig, orig_args, v, V=None, bcs=None, mesh=None, function=None,
return return_value


@override_method(backend_DirichletBC, "homogenize")
def DirichletBC_homogenize(self, orig, orig_args, *args, **kwargs):
bc_value = getattr(self, "_tlm_adjoint__bc_value", None)
if is_var(bc_value) and var_state_is_locked(bc_value):
raise VariableStateChangeError("Cannot change DirichletBC if the "
"value state is locked")
return orig_args()


@override_method(backend_DirichletBC, "set_value")
def DirichletBC_set_value(self, orig, orig_args, *args, **kwargs):
bc_value = getattr(self, "_tlm_adjoint__bc_value", None)
if is_var(bc_value) and var_state_is_locked(bc_value):
raise VariableStateChangeError("Cannot change DirichletBC if the "
"value state is locked")
return orig_args()


@override_method(backend_DirichletBC, "apply")
def _DirichletBC_apply(self, orig, orig_args, *args):
def DirichletBC_apply(self, orig, orig_args, *args):
A = None
b = None
x = None
Expand Down
17 changes: 14 additions & 3 deletions tlm_adjoint/fenics/equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
TestFunction, TrialFunction, adjoint, backend_Constant,
backend_DirichletBC, backend_Function, parameters)
from ..interface import (
check_space_type, is_var, var_assign, var_id, var_is_scalar, var_new,
var_new_conjugate_dual, var_replacement, var_scalar_value, var_space,
var_zero)
check_space_type, is_var, var_assign, var_id, var_increment_state_lock,
var_is_scalar, var_new, var_new_conjugate_dual, var_replacement,
var_scalar_value, var_space, var_zero)
from .backend_code_generator_interface import (
assemble, assemble_linear_solver, copy_parameters_dict,
form_compiler_quadrature_parameters, homogenize, interpolate_expression,
Expand All @@ -29,6 +29,7 @@
ReplacementConstant, bcs_is_cached, bcs_is_homogeneous, bcs_is_static,
derivative, eliminate_zeros, extract_coefficients)

import itertools
import numpy as np
try:
import ufl_legacy as ufl
Expand Down Expand Up @@ -391,6 +392,15 @@ def __init__(self, eq, x, bcs=None, *,

hbcs = tuple(map(homogenized_bc, bcs))

class DirichletBCLock:
pass

bc_lock = DirichletBCLock()
for bc in itertools.chain(bcs, hbcs):
bc_value = getattr(bc, "_tlm_adjoint__bc_value", None)
if is_var(bc_value):
var_increment_state_lock(bc_value, bc_lock)

if cache_jacobian is None:
cache_jacobian = is_cached(J) and bcs_is_cached(bcs)
if cache_adjoint_jacobian is None:
Expand Down Expand Up @@ -429,6 +439,7 @@ def __init__(self, eq, x, bcs=None, *,
self._rhs = rhs
self._bcs = bcs
self._hbcs = hbcs
self._bc_lock = bc_lock
self._J = J
self._nl_solve_J = nl_solve_J
self._form_compiler_parameters = form_compiler_parameters
Expand Down
Loading

0 comments on commit 3fd05c9

Please sign in to comment.