Skip to content

Commit

Permalink
Merge pull request #609 from tlm-adjoint/jrmaddison/backwards_compati…
Browse files Browse the repository at this point in the history
…bility

Firedrake backend: Remove some backwards compatibility
  • Loading branch information
jrmaddison authored Dec 17, 2024
2 parents c4e2ff7 + e70e57b commit 93ea2fe
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 40 deletions.
33 changes: 3 additions & 30 deletions tests/firedrake/test_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,20 +66,11 @@ def project_project(F, space, bc):


def project_solve(F, space, bc, *, restrict=False):
# Backwards compatibility
try:
RestrictedFunctionSpace
restrict_kwargs = {"restrict": restrict}
except NameError:
if restrict:
pytest.skip()
restrict_kwargs = {}

test, trial = TestFunction(space), TrialFunction(space)
G = Function(space, name="G")

solve(inner(trial, test) * dx == inner(F, test) * dx, G, bcs=bc,
solver_parameters=ls_parameters_cg, **restrict_kwargs)
solver_parameters=ls_parameters_cg, restrict=restrict)

return G

Expand All @@ -98,21 +89,12 @@ def project_assemble_LinearSolver(F, space, bc):


def project_LinearVariationalSolver(F, space, bc, *, restrict=False):
# Backwards compatibility
try:
RestrictedFunctionSpace
restrict_kwargs = {"restrict": restrict}
except NameError:
if restrict:
pytest.skip()
restrict_kwargs = {}

test, trial = TestFunction(space), TrialFunction(space)
G = Function(space, name="G")

eq = inner(trial, test) * dx == inner(F, test) * dx
problem = LinearVariationalProblem(eq.lhs, eq.rhs, G, bcs=bc,
**restrict_kwargs)
restrict=restrict)
solver = LinearVariationalSolver(
problem, solver_parameters=ls_parameters_cg)
solver.solve()
Expand Down Expand Up @@ -141,22 +123,13 @@ def project_LinearVariationalSolver_matfree(F, space, bc):


def project_NonlinearVariationalSolver(F, space, bc, *, restrict=False):
# Backwards compatibility
try:
RestrictedFunctionSpace
restrict_kwargs = {"restrict": restrict}
except NameError:
if restrict:
pytest.skip()
restrict_kwargs = {}

test, trial = TestFunction(space), TrialFunction(space)
G = Function(space, name="G")

eq = inner(G, test) * dx - inner(F, test) * dx
problem = NonlinearVariationalProblem(eq, G,
J=inner(trial, test) * dx,
bcs=bc, **restrict_kwargs)
bcs=bc, restrict=restrict)
solver = NonlinearVariationalSolver(
problem, solver_parameters=ns_parameters_newton_cg)
solver.solve()
Expand Down
14 changes: 4 additions & 10 deletions tlm_adjoint/firedrake/backend_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,15 +222,9 @@ def solve(*args, **kwargs):

extracted_args = extract_args(*args, **kwargs)
eq, x, bcs, J, Jp, M, form_compiler_parameters, solver_parameters, \
nullspace, transpose_nullspace, near_nullspace, options_prefix \
= extracted_args[:12]
# Backwards compatibility
if len(extracted_args) == 12:
restrict_kwargs = {}
elif len(extracted_args) == 13:
restrict_kwargs = {"restrict": extracted_args[12]}
else:
raise ValueError("Invalid extracted arguments")
nullspace, transpose_nullspace, near_nullspace, options_prefix, \
restrict \
= extracted_args
check_space_type(x, "primal")
if bcs is None:
bcs = ()
Expand Down Expand Up @@ -276,7 +270,7 @@ def solve(*args, **kwargs):
transpose_nullspace=transpose_nullspace,
near_nullspace=near_nullspace,
options_prefix=options_prefix,
**restrict_kwargs)
restrict=restrict)


class LocalSolver:
Expand Down

0 comments on commit 93ea2fe

Please sign in to comment.