Skip to content

Commit

Permalink
Merge pull request #470 from tlm-adjoint/jrmaddison/replacement_fix
Browse files Browse the repository at this point in the history
`Replacement` fix
  • Loading branch information
jrmaddison authored Nov 30, 2023
2 parents 4af7cdf + 44c0db5 commit e8fb4f2
Show file tree
Hide file tree
Showing 16 changed files with 355 additions and 466 deletions.
4 changes: 2 additions & 2 deletions docs/source/examples/6_custom_operations.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@
" return adj_x\n",
" elif dep_index == 1:\n",
" _, y = self.dependencies()\n",
" b = Cofunction(y.function_space().dual())\n",
" b = Cofunction(var_space(y).dual())\n",
" adj_x = float(adj_x)\n",
" b.dat.data[:] = -adj_x\n",
" return b\n",
Expand Down Expand Up @@ -470,7 +470,7 @@
" return adj_x\n",
" elif dep_index == 1:\n",
" _, y = self.dependencies()\n",
" b = Cofunction(y.function_space().dual())\n",
" b = Cofunction(var_space(y).dual())\n",
" adj_x = float(adj_x)\n",
" b.dat.data[:] = -adj_x\n",
" return b\n",
Expand Down
27 changes: 25 additions & 2 deletions tests/base/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
# -*- coding: utf-8 -*-

from tlm_adjoint import (
DEFAULT_COMM, Float, VariableStateLockDictionary, Vector, var_caches,
var_id, var_is_cached, var_is_static, var_lock_state, var_name,
DEFAULT_COMM, Float, SymbolicFloat, VariableStateLockDictionary, Vector,
var_caches, var_id, var_is_cached, var_is_static, var_lock_state, var_name,
var_replacement)
from tlm_adjoint.interface import (
var_decrement_state_lock, var_increment_state_lock, var_state_is_locked)
Expand Down Expand Up @@ -54,6 +54,29 @@ def test_replacement(setup_test, # noqa: F811
assert var_caches(var) is F_caches


@pytest.mark.base
@seed_test
def test_replacement_eq_hash(setup_test, # noqa: F811
var_cls):
F = var_cls()
if isinstance(F, SymbolicFloat):
pytest.skip()
F_replacement = var_replacement(F)

assert F == F
assert not (F != F)
assert F_replacement == F_replacement
assert not (F_replacement != F_replacement)

assert F != F_replacement
assert F_replacement != F
assert not (F == F_replacement)
assert not (F_replacement == F)

assert hash(F) != hash(F_replacement)
assert len(set((F, F_replacement))) == 2


@pytest.mark.base
@seed_test
def test_state_lock(setup_test): # noqa: F811
Expand Down
2 changes: 1 addition & 1 deletion tests/fenics/test_caches.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_clear_caches(setup_test, test_leaks):
F = Function(space, name="F", cache=True)

def cache_item(F, F_value=None):
form = inner(F, TestFunction(F.function_space())) * dx
form = inner(F, TestFunction(var_space(F))) * dx
cached_form, _ = assembly_cache().assemble(
form, replace_map=None if F_value is None else {F: F_value})
return cached_form
Expand Down
24 changes: 23 additions & 1 deletion tests/fenics/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_name(setup_test,
@pytest.mark.parametrize("static", [False, True])
@pytest.mark.parametrize("cache", [False, True, None])
@seed_test
def test_replacement(setup_test, # noqa: F811
def test_replacement(setup_test,
var_cls, cache, static):
name = "_tlm_adjoint__test_name"
F = var_cls(name=name, static=static, cache=cache)
Expand All @@ -53,6 +53,28 @@ def test_replacement(setup_test, # noqa: F811
assert var_caches(var) is F_caches


@pytest.mark.fenics
@seed_test
def test_replacement_eq_hash(setup_test,
var_cls):
F = var_cls()
F_replacement = var_replacement(F)

assert F == F
assert not (F != F)
assert F_replacement == F_replacement
assert not (F_replacement != F_replacement)

assert F != F_replacement
assert F_replacement != F
assert not (F == F_replacement)
assert not (F_replacement == F)

assert F.count() != F_replacement.count()
assert hash(F) != hash(F_replacement)
assert len(set((F, F_replacement))) == 2


@pytest.mark.fenics
@seed_test
def test_FunctionSpace_interface(setup_test, test_leaks):
Expand Down
2 changes: 1 addition & 1 deletion tests/firedrake/test_caches.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_clear_caches(setup_test, test_leaks):
F = Function(space, name="F", cache=True)

def cache_item(F, F_value=None):
form = inner(F, TestFunction(F.function_space())) * dx
form = inner(F, TestFunction(var_space(F))) * dx
cached_form, _ = assembly_cache().assemble(
form, replace_map=None if F_value is None else {F: F_value})
return cached_form
Expand Down
27 changes: 25 additions & 2 deletions tests/firedrake/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

@pytest.fixture(params=[{"cls": lambda **kwargs: Constant(**kwargs)},
{"cls": lambda **kwargs: Constant(domain=UnitIntervalMesh(20), **kwargs)}, # noqa: E501
{"cls": lambda **kwargs: Function(FunctionSpace(UnitIntervalMesh(20), "Lagrange", 1), **kwargs)}]) # noqa: E501
{"cls": lambda **kwargs: Function(FunctionSpace(UnitIntervalMesh(20), "Lagrange", 1), **kwargs)}, # noqa: E501
{"cls": lambda **kwargs: Cofunction(FunctionSpace(UnitIntervalMesh(20), "Lagrange", 1).dual(), **kwargs)}]) # noqa: E501
def var_cls(request):
return request.param["cls"]

Expand All @@ -36,7 +37,7 @@ def test_name(setup_test,
@pytest.mark.parametrize("static", [False, True])
@pytest.mark.parametrize("cache", [False, True, None])
@seed_test
def test_replacement(setup_test, # noqa: F811
def test_replacement(setup_test,
var_cls, cache, static):
name = "_tlm_adjoint__test_name"
F = var_cls(name=name, static=static, cache=cache)
Expand All @@ -53,6 +54,28 @@ def test_replacement(setup_test, # noqa: F811
assert var_caches(var) is F_caches


@pytest.mark.firedrake
@seed_test
def test_replacement_eq_hash(setup_test,
var_cls):
F = var_cls()
F_replacement = var_replacement(F)

assert F == F
assert not (F != F)
assert F_replacement == F_replacement
assert not (F_replacement != F_replacement)

assert F != F_replacement
assert F_replacement != F
assert not (F == F_replacement)
assert not (F_replacement == F)

assert F.count() != F_replacement.count()
assert hash(F) != hash(F_replacement)
assert len(set((F, F_replacement))) == 2


@pytest.mark.firedrake
@seed_test
def test_FunctionSpace_interface(setup_test, test_leaks):
Expand Down
1 change: 1 addition & 0 deletions tlm_adjoint/fenics/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,6 @@
backend_solve = solve

cpp_Assembler = fenics.cpp.fem.Assembler
cpp_Constant = fenics.cpp.function.Constant
cpp_PETScVector = fenics.cpp.la.PETScVector
cpp_SystemAssembler = fenics.cpp.fem.SystemAssembler
49 changes: 2 additions & 47 deletions tlm_adjoint/fenics/backend_code_generator_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
# -*- coding: utf-8 -*-

from .backend import (
FunctionSpace, LUSolver, KrylovSolver, Parameters, TensorFunctionSpace,
TestFunction, UserExpression, VectorFunctionSpace, as_backend_type,
backend_Constant, backend_DirichletBC, backend_Function,
LUSolver, KrylovSolver, Parameters, TestFunction, UserExpression,
as_backend_type, backend_Constant, backend_DirichletBC, backend_Function,
backend_ScalarType, backend_Vector, backend_assemble,
backend_assemble_system, backend_solve as solve, complex_mode,
has_lu_solver_method, parameters)
Expand Down Expand Up @@ -262,50 +261,6 @@ def matrix_multiply(A, x, *,
return tensor


@manager_disabled()
def is_valid_r0_space(space):
if not hasattr(space, "_tlm_adjoint__is_valid_r0_space"):
e = space.ufl_element()
if e.family() != "Real" or e.degree() != 0:
valid = False
elif len(e.value_shape()) == 0:
r = backend_Function(space)
r.assign(backend_Constant(1.0))
valid = (r.vector().max() == 1.0)
else:
r = backend_Function(space)
r_arr = np.arange(1, np.prod(r.ufl_shape) + 1,
dtype=backend_ScalarType)
r_arr.shape = r.ufl_shape
r.assign(backend_Constant(r_arr))
for i, r_c in enumerate(r.split(deepcopy=True)):
if r_c.vector().max() != i + 1:
valid = False
break
else:
valid = True
space._tlm_adjoint__is_valid_r0_space = valid
return space._tlm_adjoint__is_valid_r0_space


def r0_space(x):
if not hasattr(x, "_tlm_adjoint__r0_space"):
domain = var_space(x)._tlm_adjoint__space_interface_attrs["domain"]
domain = domain.ufl_cargo()
if len(x.ufl_shape) == 0:
space = FunctionSpace(domain, "R", 0)
elif len(x.ufl_shape) == 1:
space = VectorFunctionSpace(domain, "R", 0,
dim=ufl.shape[0])
else:
space = TensorFunctionSpace(domain, "R", degree=0,
shape=x.ufl_shape)
if not is_valid_r0_space(space):
raise RuntimeError("Invalid space")
x._tlm_adjoint__r0_space = space
return x._tlm_adjoint__r0_space


def rhs_copy(x):
if not isinstance(x, backend_Vector):
raise TypeError("Invalid RHS")
Expand Down
38 changes: 19 additions & 19 deletions tlm_adjoint/fenics/backend_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,18 @@
backend_ScalarType, backend_Vector, cpp_PETScVector)
from ..interface import (
DEFAULT_COMM, SpaceInterface, VariableInterface, add_interface,
check_space_types, comm_dup_cached, new_space_id, new_var_id,
check_space_types, comm_dup_cached, is_var, new_space_id, new_var_id,
register_functional_term_eq, register_subtract_adjoint_derivative_action,
space_id, subtract_adjoint_derivative_action_base, var_axpy, var_copy,
var_linf_norm, var_lock_state, var_new, var_space, var_space_type)
from .backend_code_generator_interface import r0_space

from ..equations import Conversion
from ..override import override_method

from .equations import Assembly
from .functions import (
Caches, ConstantInterface, ConstantSpaceInterface, ReplacementFunction,
Zero, define_var_alias)
Zero, define_var_alias, new_count, r0_space)

import functools
import numbers
Expand All @@ -44,7 +43,10 @@ def Constant__init__(self, orig, orig_args, *args, domain=None, space=None,
if domain is not None and hasattr(domain, "ufl_domain"):
domain = domain.ufl_domain()
if comm is None:
comm = DEFAULT_COMM
if domain is None:
comm = DEFAULT_COMM
else:
comm = domain.ufl_cargo().mpi_comm()

orig(self, *args, **kwargs)

Expand All @@ -56,9 +58,9 @@ def Constant__init__(self, orig, orig_args, *args, domain=None, space=None,
add_interface(self, ConstantInterface,
{"id": new_var_id(), "name": lambda x: x.name(),
"state": [0], "space": space,
"derivative_space": lambda x: r0_space(x),
"space_type": "primal", "dtype": self.values().dtype.type,
"static": False, "cache": False})
"static": False, "cache": False,
"replacement_count": new_count()})


class FunctionSpaceInterface(SpaceInterface):
Expand Down Expand Up @@ -135,9 +137,6 @@ class FunctionInterface(VariableInterface):
def _space(self):
return self._tlm_adjoint__var_interface_attrs["space"]

def _derivative_space(self):
return var_space(self)

def _space_type(self):
return self._tlm_adjoint__var_interface_attrs["space_type"]

Expand Down Expand Up @@ -247,9 +246,11 @@ def _copy(self, *, name=None, static=False, cache=None):
return y

def _replacement(self):
if not hasattr(self, "_tlm_adjoint__replacement"):
self._tlm_adjoint__replacement = ReplacementFunction(self)
return self._tlm_adjoint__replacement
if "replacement" not in self._tlm_adjoint__var_interface_attrs:
count = self._tlm_adjoint__var_interface_attrs["replacement_count"]
self._tlm_adjoint__var_interface_attrs["replacement"] = \
ReplacementFunction(self, count=count)
return self._tlm_adjoint__var_interface_attrs["replacement"]

def _is_replacement(self):
return False
Expand Down Expand Up @@ -318,25 +319,24 @@ def Function__init__(self, orig, orig_args, *args, **kwargs):
if not isinstance(as_backend_type(self.vector()), cpp_PETScVector):
raise RuntimeError("PETSc backend required")

space = self.function_space()
add_interface(self, FunctionInterface,
{"id": new_var_id(), "state": [0],
"space_type": "primal", "static": False, "cache": False})
{"id": new_var_id(), "state": [0], "space": space,
"space_type": "primal", "static": False, "cache": False,
"replacement_count": new_count()})

space = self.function_space()
if isinstance(args[0], backend_FunctionSpace) and args[0].id() == space.id(): # noqa: E501
id = space_id(args[0])
else:
id = new_space_id()
add_interface(space, FunctionSpaceInterface,
{"comm": comm_dup_cached(space.mesh().mpi_comm()), "id": id})
self._tlm_adjoint__var_interface_attrs["space"] = space


@override_method(backend_Function, "function_space")
def Function_function_space(self, orig, orig_args):
if hasattr(self, "_tlm_adjoint__var_interface_attrs") \
and "space" in self._tlm_adjoint__var_interface_attrs:
return self._tlm_adjoint__var_interface_attrs["space"]
if is_var(self):
return var_space(self)
else:
return orig_args()

Expand Down
Loading

0 comments on commit e8fb4f2

Please sign in to comment.