Skip to content

Commit

Permalink
Merge pull request #590 from tlm-adjoint/jrmaddison/garbage_cleanup
Browse files Browse the repository at this point in the history
Attempt to work around `garbage_cleanup` deadlock
  • Loading branch information
jrmaddison authored Jul 18, 2024
2 parents dbde2ec + 924dd7a commit 4bcc3ff
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 18 deletions.
11 changes: 6 additions & 5 deletions tests/base/test_jax.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from tlm_adjoint import (
DEFAULT_COMM, DotProduct, Float, Hessian, Vector, VectorEquation,
comm_parent, compute_gradient, new_jax_float, set_default_float_dtype,
set_default_jax_dtype, start_manager, stop_manager, taylor_test,
taylor_test_tlm, taylor_test_tlm_adjoint, to_float, var_comm,
var_get_values, var_global_size, var_is_scalar, var_linf_norm,
var_local_size, var_scalar_value)
comm_parent, compute_gradient, garbage_cleanup, new_jax_float,
set_default_float_dtype, set_default_jax_dtype, start_manager,
stop_manager, taylor_test, taylor_test_tlm, taylor_test_tlm_adjoint,
to_float, var_comm, var_get_values, var_global_size, var_is_scalar,
var_linf_norm, var_local_size, var_scalar_value)

from .test_base import jax_tlm_config, seed_test, setup_test # noqa: F401

Expand Down Expand Up @@ -305,4 +305,5 @@ def test_jax_to_float(setup_test, # noqa: F811
assert var_is_scalar(x)
assert var_scalar_value(x) == x_val
finally:
garbage_cleanup(comm)
comm.Free()
16 changes: 8 additions & 8 deletions tests/firedrake/test_equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,7 +921,7 @@ def forward_J(x):

min_order = taylor_test_tlm_adjoint(forward_J, x, adjoint_order=2,
seed=1.0e-3)
assert min_order > 1.99
assert min_order > 1.98

h.close()

Expand Down Expand Up @@ -1239,25 +1239,25 @@ def forward(m):

dm = Function(space).interpolate(Constant(1.0) + X[0])

min_order = taylor_test(forward, m, J_val=J_val, dJ=dJ, seed=5.0e-3, dM=dm)
min_order = taylor_test(forward, m, J_val=J_val, dJ=dJ, seed=1.0e-4, dM=dm)
assert min_order > 1.99

ddJ = Hessian(forward)
min_order = taylor_test(forward, m, J_val=J_val, ddJ=ddJ, seed=5.0e-3,
min_order = taylor_test(forward, m, J_val=J_val, ddJ=ddJ, seed=1.0e-3,
dM=dm)
assert min_order > 2.98
assert min_order > 2.97

min_order = taylor_test_tlm(forward, m, tlm_order=1, seed=5.0e-3,
min_order = taylor_test_tlm(forward, m, tlm_order=1, seed=1.0e-4,
dMs=(dm,))
assert min_order > 1.99

min_order = taylor_test_tlm_adjoint(forward, m, adjoint_order=1,
seed=5.0e-3, dMs=(dm,))
seed=1.0e-4, dMs=(dm,))
assert min_order > 1.99

min_order = taylor_test_tlm_adjoint(forward, m, adjoint_order=2,
seed=5.0e-3)
assert min_order > 1.98
seed=1.0e-4)
assert min_order > 1.99


def test_DirichletBC_overlap(setup_test, test_leaks):
Expand Down
2 changes: 1 addition & 1 deletion tests/firedrake/test_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ def forward(m):

ddJ = Hessian(forward)
min_order = taylor_test(forward, m, J_val=J_val, ddJ=ddJ, seed=1.0e-4,
size=3)
size=2)
assert min_order > 2.99

min_order = taylor_test_tlm(forward, m, tlm_order=1, seed=1.0e-4)
Expand Down
17 changes: 13 additions & 4 deletions tlm_adjoint/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
This module defines a default communicator `DEFAULT_COMM`.
"""

from .alias import gc_disabled
from .manager import manager_disabled

from collections.abc import MutableMapping, Sequence
Expand Down Expand Up @@ -297,7 +298,6 @@ def finalize_callback(comm_py2f, key, dup_comm):
_parent_comms.pop(dup_comm.py2f(), None)
_dupped_comms.pop(comm_py2f, None)
_dup_comms.pop(key, None)
garbage_cleanup(dup_comm)
if MPI is not None and not MPI.Is_finalized():
dup_comm.Free()

Expand All @@ -323,6 +323,7 @@ def garbage_cleanup_base(comm):
register_garbage_cleanup(garbage_cleanup_base)


@gc_disabled
def garbage_cleanup(comm=None):
"""Call `petsc4py.PETSc.garbage_cleanup(comm)` for a communicator, and any
communicators duplicated from it using :func:`.comm_dup_cached`.
Expand All @@ -348,9 +349,17 @@ def garbage_cleanup(comm=None):
comms[comm.py2f()] = comm
comm_stack.extend(_dupped_comms.get(comm.py2f(), {}).values())

for comm in comms.values():
for fn in _garbage_cleanup:
fn(comm)
if PETSc is not None:
petsc_comms = tuple(PETSc.Comm(comm).duplicate()
for comm in comms.values())
try:
for comm in comms.values():
for fn in _garbage_cleanup:
fn(comm)
finally:
if PETSc is not None:
for comm in petsc_comms:
comm.destroy()


def weakref_method(fn, obj):
Expand Down
1 change: 1 addition & 0 deletions tlm_adjoint/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def duplicated_comm(comm):
try:
yield dup_comm
finally:
garbage_cleanup(dup_comm)
dup_comm.Free()


Expand Down

0 comments on commit 4bcc3ff

Please sign in to comment.