Skip to content

Commit

Permalink
Merge pull request #610 from tlm-adjoint/jrmaddison/Cofunction_interp…
Browse files Browse the repository at this point in the history
…olation

Firedrake backend: `Cofunction` interpolation
  • Loading branch information
jrmaddison authored Dec 17, 2024
2 parents 93ea2fe + 3264348 commit 53ea515
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 21 deletions.
61 changes: 58 additions & 3 deletions tests/firedrake/test_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,8 +803,8 @@ def forward_J(F):
@pytest.mark.firedrake
@pytest.mark.parametrize("degree", [1, 2, 3])
@seed_test
def test_interpolate(setup_test, test_leaks,
interpolate_expr, degree):
def test_Function_interpolate(setup_test, test_leaks,
interpolate_expr, degree):
mesh = UnitIntervalMesh(20)
X = SpatialCoordinate(mesh)
space_1 = FunctionSpace(mesh, "Lagrange", 1)
Expand All @@ -828,7 +828,6 @@ def forward(y_2):
J.assign(((y_1 - Constant(1.0)) ** 4) * dx)
return y_1, J

reset_manager("memory", {"drop_references": True})
start_manager()
y_1, J = forward(y_2)
stop_manager()
Expand Down Expand Up @@ -862,6 +861,62 @@ def forward_J(y_2):
assert min_order > 1.99


@pytest.mark.firedrake
@seed_test
def test_Cofunction_interpolate(setup_test, test_leaks):
mesh = UnitIntervalMesh(10)
X = SpatialCoordinate(mesh)
space_1 = FunctionSpace(mesh, "Lagrange", 1)
space_2 = FunctionSpace(mesh, "Lagrange", 2)

def forward(b):
y = Cofunction(space_2.dual(), name="y").interpolate(b)
y_dual = y.riesz_representation(
"L2", solver_parameters=ls_parameters_cg)
J = Functional(name="J")
J.assign(((y_dual + Constant(1)) ** 4) * dx)
return y, J

b = assemble(inner(exp(X[0]), TestFunction(space_1)) * dx)

start_manager()
y, J = forward(b)
stop_manager()

for n in range(space_2.dim()):
u = Function(space_2, name="u")
with u.dat.vec_wo as u_v:
u_v.setValue(n, 1)
u_v.assemblyBegin()
u_v.assemblyEnd()
error = abs(assemble(y(u) - b(Function(space_1).interpolate(u))))
assert error == 0

J_val = J.value

dJ = compute_gradient(J, b)

def forward_J(b):
_, J = forward(b)
return J

min_order = taylor_test(forward_J, b, J_val=J_val, dJ=dJ)
assert min_order > 2.00

ddJ = Hessian(forward_J)
min_order = taylor_test(forward_J, b, J_val=J_val, ddJ=ddJ)
assert min_order > 3.00

min_order = taylor_test_tlm(forward_J, b, tlm_order=1)
assert min_order > 2.00

min_order = taylor_test_tlm_adjoint(forward_J, b, adjoint_order=1)
assert min_order > 2.00

min_order = taylor_test_tlm_adjoint(forward_J, b, adjoint_order=2)
assert min_order > 2.00


@pytest.mark.firedrake
@pytest.mark.skipif(complex_mode, reason="real only")
@seed_test
Expand Down
3 changes: 1 addition & 2 deletions tlm_adjoint/firedrake/backend_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,12 +737,11 @@ def SameMeshInterpolator_interpolate_post_call(
def SameMeshInterpolator_interpolate(
self, orig, orig_args, *function, output=None, transpose=False,
default_missing_val=None, **kwargs):
if transpose:
raise NotImplementedError("transpose not supported")
if default_missing_val is not None:
raise NotImplementedError("default_missing_val not supported")

return_value = orig_args()
check_space_type(return_value, "conjugate_dual" if transpose else "primal")

args = ufl.algorithms.extract_arguments(self.expr)
if len(args) != len(function):
Expand Down
39 changes: 23 additions & 16 deletions tlm_adjoint/firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,18 @@
FunctionSpace, Interpolator, TestFunction, VertexOnlyMesh,
backend_Cofunction, backend_Constant, backend_Function)
from ..interface import (
check_space_type, comm_dup_cached, packed, space_new, var_assign, var_comm,
var_copy, var_id, var_inner, var_is_scalar, var_new_conjugate_dual,
var_replacement, var_scalar_value)
check_space_type, check_space_types, comm_dup_cached, packed, space_new,
var_assign, var_assign_conjugate, var_axpy, var_axpy_conjugate, var_comm,
var_copy_conjugate, var_id, var_inner, var_is_scalar, var_new,
var_new_conjugate, var_new_conjugate_dual, var_replacement,
var_scalar_value, var_zero)

from ..equation import Equation, ZeroAssignment
from ..manager import manager_disabled

from .expr import (
ExprEquation, derivative, eliminate_zeros, expr_zero, extract_dependencies,
extract_variables)
iter_expr)
from .variables import ReplacementConstant

import itertools
Expand All @@ -30,19 +32,18 @@

@manager_disabled()
def interpolate_expression(x, expr, *, adj_x=None):
if adj_x is None:
check_space_type(x, "primal")
else:
check_space_type(x, "conjugate_dual")
check_space_type(adj_x, "conjugate_dual")
for dep in extract_variables(expr):
check_space_type(dep, "primal")
if adj_x is not None:
check_space_types(x, adj_x)

expr = eliminate_zeros(expr)

if adj_x is None:
if isinstance(x, backend_Constant):
x.assign(expr)
elif isinstance(x, backend_Cofunction):
var_zero(x)
for weight, comp in iter_expr(expr):
var_axpy(x, weight, var_new(x).interpolate(comp))
elif isinstance(x, backend_Function):
x.interpolate(expr)
else:
Expand All @@ -54,11 +55,17 @@ def interpolate_expression(x, expr, *, adj_x=None):
interpolate_expression(expr_val, expr)
var_assign(x, var_inner(adj_x, expr_val))
elif isinstance(x, backend_Cofunction):
adj_x = var_copy_conjugate(adj_x)
interp = Interpolator(expr, adj_x.function_space().dual())
adj_x = var_copy(adj_x)
adj_x.dat.data[:] = adj_x.dat.data_ro.conjugate()
interp._interpolate(adj_x, transpose=True, output=x)
x.dat.data[:] = x.dat.data_ro.conjugate()
x_comp = var_new_conjugate(x)
interp._interpolate(adj_x, transpose=True, output=x_comp)
var_assign_conjugate(x, x_comp)
elif isinstance(x, backend_Function):
adj_x = var_copy_conjugate(adj_x)
var_zero(x)
for weight, comp in iter_expr(expr):
x_comp = var_new_conjugate(x).interpolate(comp(adj_x))
var_axpy_conjugate(x, weight.conjugate(), x_comp)
else:
raise TypeError(f"Unexpected type: {type(x)}")

Expand All @@ -75,7 +82,7 @@ class ExprInterpolation(ExprEquation):
"""

def __init__(self, x, rhs):
deps, nl_deps = extract_dependencies(rhs, space_type="primal")
deps, nl_deps = extract_dependencies(rhs)
if var_id(x) in deps:
raise ValueError("Invalid dependency")
deps, nl_deps = list(deps.values()), tuple(nl_deps.values())
Expand Down

0 comments on commit 53ea515

Please sign in to comment.