Tidying #569

Merged · 7 commits · May 29, 2024
4 changes: 2 additions & 2 deletions .github/workflows/test-firedrake.yml
@@ -30,7 +30,7 @@ jobs:
      - name: Install dependencies
        run: |
          . /home/firedrake/firedrake/bin/activate
-          python3 -m pip install jax jaxlib ruff pytest-timeout pytest-xdist
+          python3 -m pip install jax[cpu] ruff pytest-timeout pytest-xdist
      - name: Lint
        run: |
          . /home/firedrake/firedrake/bin/activate
@@ -59,7 +59,7 @@ jobs:
      - name: Install dependencies
        run: |
          . /home/firedrake/firedrake/bin/activate
-          python3 -m pip install jax jaxlib ruff pytest-timeout pytest-xdist
+          python3 -m pip install jax[cpu] ruff pytest-timeout pytest-xdist
      - name: Lint
        run: |
          . /home/firedrake/firedrake/bin/activate
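The jax[cpu] extra installs a compatible CPU-only jaxlib alongside jax, so the separate jaxlib requirement can be dropped. A minimal sketch (not part of the PR) of checking that such an environment exposes a usable CPU backend:

import jax
import jax.numpy as jnp

# A jax[cpu] install is expected to report CPU devices only
print(jax.devices())
# A trivial computation evaluated on the CPU backend; prints 3.0
print(jnp.ones(3).sum())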
10 changes: 5 additions & 5 deletions tests/base/test_interface.py
@@ -87,20 +87,20 @@ class Test:
    t = Test()
    assert not var_state_is_locked(f)
    for _ in range(10):
-        var_increment_state_lock(f, t)
+        var_increment_state_lock(t, f)
    assert var_state_is_locked(f)
    # ... then decrement 10 times with the same object
    for _ in range(9):
-        var_decrement_state_lock(f, t)
+        var_decrement_state_lock(t, f)
    assert var_state_is_locked(f)
-    var_decrement_state_lock(f, t)
+    var_decrement_state_lock(t, f)
    assert not var_state_is_locked(f)

    # Increment 10 times with the same object ...
    t = Test()
    assert not var_state_is_locked(f)
    for _ in range(10):
-        var_increment_state_lock(f, t)
+        var_increment_state_lock(t, f)
    assert var_state_is_locked(f)
    # ... then destroy the object
    del t
@@ -110,7 +110,7 @@ class Test:
    assert not var_state_is_locked(f)
    T = [Test() for _ in range(10)]
    for t, _ in itertools.product(T, range(10)):
-        var_increment_state_lock(f, t)
+        var_increment_state_lock(t, f)
    assert var_state_is_locked(f)
    # ... then destroy the objects
    t = None
5 changes: 2 additions & 3 deletions tlm_adjoint/cached_hessian.py
@@ -143,7 +143,7 @@ class CachedHessian(Hessian, HessianOptimization):
    """

    def __init__(self, J, *, manager=None, cache_adjoint=True):
-        var_increment_state_lock(J, self)
+        var_increment_state_lock(self, J)

        HessianOptimization.__init__(self, manager=manager,
                                     cache_adjoint=cache_adjoint)
@@ -215,8 +215,7 @@ class CachedGaussNewton(GaussNewton, HessianOptimization):
    def __init__(self, X, R_inv_action, B_inv_action=None, *,
                 manager=None):
        X = packed(X)
-        for x in X:
-            var_increment_state_lock(x, self)
+        var_increment_state_lock(self, *X)

        HessianOptimization.__init__(self, manager=manager,
                                     cache_adjoint=False)
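Passing several variables in one call is, per the updated implementation of var_increment_state_lock in tlm_adjoint/interface.py further down, equivalent to the per-variable loop it replaces. A sketch of the equivalence, assuming X is a sequence of variables and obj is the object holding the lock:

var_increment_state_lock(obj, *X)
# behaves the same as the loop this PR removes:
for x in X:
    var_increment_state_lock(obj, x)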
4 changes: 2 additions & 2 deletions tlm_adjoint/fenics/backend_patches.py
@@ -87,11 +87,11 @@ def __setitem__(self, key, value):
        if id(obj) not in self._keys:
            self._keys[id(obj)] = []

-        def weakref_finalize(obj_id, d, keys):
+        def finalize_callback(obj_id, d, keys):
            for key in keys.pop(obj_id, []):
                d.pop(key, None)

-        weakref.finalize(obj, weakref_finalize,
+        weakref.finalize(obj, finalize_callback,
                         id(obj), self._d, self._keys)

        self._d[key] = value
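Only the callback name changes here; the pattern of purging cached entries when their owner is collected stays the same. A self-contained sketch of that pattern with hypothetical names, using only the standard library:

import weakref

class KeyedCache:
    # Cache entries keyed on an owner object, dropped when the owner is
    # garbage collected.
    def __init__(self):
        self._d = {}

    def add(self, obj, key, value):
        if id(obj) not in self._d:
            self._d[id(obj)] = {}

            def finalize_callback(obj_id, d):
                # Called when obj is collected; keeps no strong reference
                # to obj itself
                d.pop(obj_id, None)

            weakref.finalize(obj, finalize_callback, id(obj), self._d)
        self._d[id(obj)][key] = value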
19 changes: 15 additions & 4 deletions tlm_adjoint/firedrake/backend_patches.py
@@ -252,14 +252,26 @@ def new_space_id_cached(space):
    return space_ids[key]


+def space_local_indices_cached(space, cls):
+    if not hasattr(space, "_tlm_adjoint__local_indices"):
+        space._tlm_adjoint__local_indices = {}
+    local_indices = space._tlm_adjoint__local_indices
+
+    # Work around Firedrake issue #3130
+    key = (space, ufl.duals.is_primal(space))
+    if key not in local_indices:
+        with cls(space).dat.vec_ro as x_v:
+            local_indices[key] = x_v.getOwnershipRange()
+    return local_indices[key]
+
+
@patch_method(backend_FunctionSpace, "__init__")
def FunctionSpace__init__(self, orig, orig_args, *args, **kwargs):
    orig_args()
    add_interface(self, FunctionSpaceInterface,
                  {"space": self, "comm": comm_dup_cached(self.comm),
                   "id": new_space_id_cached(self)})
-    with backend_Function(self).dat.vec_ro as x_v:
-        n0, n1 = x_v.getOwnershipRange()
+    n0, n1 = space_local_indices_cached(self, backend_Function)
    self._tlm_adjoint__space_interface_attrs["local_indices"] = (n0, n1)


@@ -279,8 +291,7 @@ def CofunctionSpace__init__(self, orig, orig_args, *args, **kwargs):
    add_interface(self, FunctionSpaceInterface,
                  {"space_dual": self, "comm": comm_dup_cached(self.comm),
                   "id": new_space_id_cached(self)})
-    with backend_Cofunction(self).dat.vec_ro as x_v:
-        n0, n1 = x_v.getOwnershipRange()
+    n0, n1 = space_local_indices_cached(self, backend_Cofunction)
    self._tlm_adjoint__space_interface_attrs["local_indices"] = (n0, n1)


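space_local_indices_cached memoises the PETSc ownership range per space, keyed on whether the space is primal or dual, so the temporary Function or Cofunction only has to be created once per space. A short sketch of what the cached pair means, assuming a distributed petsc4py Vec v:

n0, n1 = v.getOwnershipRange()
# This MPI rank owns global indices n0 <= i < n1
local_size = n1 - n0  # equals v.getLocalSize()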
7 changes: 3 additions & 4 deletions tlm_adjoint/hessian_system.py
@@ -1,7 +1,7 @@
from .interface import (
    Packed, packed, var_copy_conjugate, var_increment_state_lock, var_locked,
    var_space, vars_assign, vars_assign_conjugate, vars_axpy,
-    vars_axpy_conjugate, vars_copy_conjugate, vars_inner)
+    vars_axpy_conjugate, vars_inner)

from .block_system import (
    Eigensolver, LinearSolver, Matrix, MatrixFreeMatrix, TypedSpace)
@@ -35,8 +35,7 @@ def __init__(self, H, M):
        self._H = H
        self._M = M

-        for m in M:
-            var_increment_state_lock(m, self)
+        var_increment_state_lock(self, *M)

    def mult_add(self, x, y):
        x = packed(x)
@@ -120,7 +119,7 @@ def B_action(x, y):
        B_inv_action_arg = B_inv_action

        def B_inv_action(x, y):
-            x = vars_copy_conjugate(packed(x))
+            x = packed(x)
            y = packed(y)

            with var_locked(*x):
100 changes: 48 additions & 52 deletions tlm_adjoint/interface.py
@@ -986,52 +986,55 @@ def var_state(x):
    return x._tlm_adjoint__var_interface_state()


-def var_increment_state_lock(x, obj):
-    if var_is_replacement(x):
-        raise ValueError("x cannot be a replacement")
-    var_check_state_lock(x)
-    x_id = var_id(x)
+def var_increment_state_lock(obj, *X):
+    for x in X:
+        if var_is_replacement(x):
+            raise ValueError("x cannot be a replacement")
+        var_check_state_lock(x)
+        x_id = var_id(x)

-    if not hasattr(x, "_tlm_adjoint__state_lock"):
-        x._tlm_adjoint__state_lock = 0
-    if x._tlm_adjoint__state_lock == 0:
-        x._tlm_adjoint__state_lock_state = var_state(x)
+        if not hasattr(x, "_tlm_adjoint__state_lock"):
+            x._tlm_adjoint__state_lock = 0
+        if x._tlm_adjoint__state_lock == 0:
+            x._tlm_adjoint__state_lock_state = var_state(x)

-    # Functionally similar to a weakref.WeakKeyDictionary, using the variable
-    # ID as a key. This approach does not require obj to be hashable.
-    if not hasattr(obj, "_tlm_adjoint__state_locks"):
-        obj._tlm_adjoint__state_locks = {}
+        # Functionally similar to a weakref.WeakKeyDictionary, using the
+        # variable ID as a key. This approach does not require obj to be
+        # hashable.
+        if not hasattr(obj, "_tlm_adjoint__state_locks"):
+            obj._tlm_adjoint__state_locks = {}

-        def weakref_finalize(locks):
-            for x_ref, count in locks.values():
-                x = x_ref()
-                if x is not None and hasattr(x, "_tlm_adjoint__state_lock"):
-                    x._tlm_adjoint__state_lock -= count
+            def finalize_callback(locks):
+                for x_ref, count in locks.values():
+                    x = x_ref()
+                    if x is not None and hasattr(x, "_tlm_adjoint__state_lock"):  # noqa: E501
+                        x._tlm_adjoint__state_lock -= count

-        weakref.finalize(obj, weakref_finalize,
-                         obj._tlm_adjoint__state_locks)
-    if x_id not in obj._tlm_adjoint__state_locks:
-        obj._tlm_adjoint__state_locks[x_id] = [weakref.ref(x), 0]
+            weakref.finalize(obj, finalize_callback,
+                             obj._tlm_adjoint__state_locks)
+        if x_id not in obj._tlm_adjoint__state_locks:
+            obj._tlm_adjoint__state_locks[x_id] = [weakref.ref(x), 0]

-    x._tlm_adjoint__state_lock += 1
-    obj._tlm_adjoint__state_locks[x_id][1] += 1
+        x._tlm_adjoint__state_lock += 1
+        obj._tlm_adjoint__state_locks[x_id][1] += 1


-def var_decrement_state_lock(x, obj):
-    if var_is_replacement(x):
-        raise ValueError("x cannot be a replacement")
-    var_check_state_lock(x)
-    x_id = var_id(x)
+def var_decrement_state_lock(obj, *X):
+    for x in X:
+        if var_is_replacement(x):
+            raise ValueError("x cannot be a replacement")
+        var_check_state_lock(x)
+        x_id = var_id(x)

-    if x._tlm_adjoint__state_lock < obj._tlm_adjoint__state_locks[x_id][1]:
-        raise RuntimeError("Invalid state lock")
-    if obj._tlm_adjoint__state_locks[x_id][1] < 1:
-        raise RuntimeError("Invalid state lock")
+        if x._tlm_adjoint__state_lock < obj._tlm_adjoint__state_locks[x_id][1]:
+            raise RuntimeError("Invalid state lock")
+        if obj._tlm_adjoint__state_locks[x_id][1] < 1:
+            raise RuntimeError("Invalid state lock")

-    x._tlm_adjoint__state_lock -= 1
-    obj._tlm_adjoint__state_locks[x_id][1] -= 1
-    if obj._tlm_adjoint__state_locks[x_id][1] == 0:
-        del obj._tlm_adjoint__state_locks[x_id]
+        x._tlm_adjoint__state_lock -= 1
+        obj._tlm_adjoint__state_locks[x_id][1] -= 1
+        if obj._tlm_adjoint__state_locks[x_id][1] == 0:
+            del obj._tlm_adjoint__state_locks[x_id]


class VariableStateChangeError(RuntimeError):
@@ -1048,7 +1051,7 @@ class Lock:
        pass

    lock = x._tlm_adjoint__state_lock_lock = Lock()
-    var_increment_state_lock(x, lock)
+    var_increment_state_lock(lock, x)


def var_state_is_locked(x):
@@ -1080,7 +1083,7 @@ def __init__(self, *args, **kwargs):
        self._d = dict(*args, **kwargs)
        for value in self._d.values():
            if is_var(value) and not var_is_replacement(value):
-                var_increment_state_lock(value, self)
+                var_increment_state_lock(self, value)

    def __getitem__(self, key):
        value = self._d[key]
@@ -1092,15 +1095,15 @@ def __setitem__(self, key, value):
        oldvalue = self._d.get(key, None)
        self._d[key] = value
        if is_var(value) and not var_is_replacement(value):
-            var_increment_state_lock(value, self)
+            var_increment_state_lock(self, value)
        if is_var(oldvalue) and not var_is_replacement(oldvalue):
-            var_decrement_state_lock(oldvalue, self)
+            var_decrement_state_lock(self, oldvalue)

    def __delitem__(self, key):
        oldvalue = self._d.get(key, None)
        del self._d[key]
        if is_var(oldvalue) and not var_is_replacement(oldvalue):
-            var_decrement_state_lock(oldvalue, self)
+            var_decrement_state_lock(self, oldvalue)

    def __iter__(self):
        yield from self._d
@@ -1123,14 +1126,12 @@ class Lock:
        pass

    lock = Lock()
-    for x in X:
-        var_increment_state_lock(x, lock)
+    var_increment_state_lock(lock, *X)

    try:
        yield
    finally:
-        for x in X:
-            var_decrement_state_lock(x, lock)
+        var_decrement_state_lock(lock, *X)


def var_update_state(*X):
@@ -1638,16 +1639,11 @@ def __init__(self, obj):
        self._is_packed = is_packed

    def __eq__(self, other):
-        other = Packed(other)
-        return (tuple(self) == tuple(other)
-                and self.is_packed == other.is_packed)
+        return tuple(self) == tuple(Packed(other))

    def __ne__(self, other):
        return not self == other

-    def __hash__(self):
-        return hash((tuple(self), self.is_packed))
-
    def __len__(self):
        return len(self._t)

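With the reordered signatures, the lock-holding object always comes first, followed by any number of variables. A usage sketch of the updated API; Holder, u and v are illustrative names, not from the PR:

class Holder:
    pass

holder = Holder()
var_increment_state_lock(holder, u, v)   # holder locks the states of u and v
assert var_state_is_locked(u) and var_state_is_locked(v)
var_decrement_state_lock(holder, u, v)   # release both locks
assert not var_state_is_locked(u)

# var_locked wraps the same increment/decrement pair as a context manager
with var_locked(u, v):
    ...  # u and v must not be modified here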
8 changes: 4 additions & 4 deletions tlm_adjoint/optimization.py
@@ -309,7 +309,7 @@ def M(*X):
        M_X = packed(M_X)
        if len(M_X) != len(X):
            raise ValueError("Incompatible shape")
-        return vars_copy(M_X)
+        return M_X

    return M

@@ -400,7 +400,7 @@ def inverse_action(self, X, *,
            alphas.append(alpha)
        alphas.reverse()

-        R = H_0_action(*X)
+        R = vars_copy(H_0_action(*X))
        if theta != 1.0:
            for r in R:
                var_set_values(r, var_get_values(r) / theta)
@@ -520,8 +520,8 @@ def objective_gradient(taols, x, g):
        options["tao_ls_gtol"] = c2
        taols.setOptionsPrefix(options.options_prefix)

-        taols.setUp()
        taols.setFromOptions()
+        taols.setUp()

        x = PETScVec(vec_interface)
        x.to_petsc(X)
@@ -723,7 +723,7 @@ def Fp(*X):
        Fp_val = packed(Fp_val)
        if len(Fp_val) != len(X):
            raise ValueError("Incompatible shape")
-        return vars_copy(Fp_val)
+        return Fp_val

    X0_packed = Packed(X0)
    X0 = tuple(X0_packed)
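Swapping setFromOptions and setUp follows the usual PETSc ordering: read the options database first so that the selected options can influence the set-up. A generic petsc4py sketch of the same ordering, using a TAO solver for illustration (the object in the diff is a TAO line search):

from petsc4py import PETSc

tao = PETSc.TAO().create(comm=PETSc.COMM_WORLD)
tao.setType(PETSc.TAO.Type.LMVM)
tao.setOptionsPrefix("example_")
tao.setFromOptions()  # apply any -example_* command line options first
tao.setUp()           # then set up with those options in effect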
4 changes: 2 additions & 2 deletions tlm_adjoint/tangent_linear.py
@@ -175,13 +175,13 @@ def __init__(self, M, dM):
        self._dM = dM

        @gc_disabled
-        def weakref_finalize(X, tlm_map_id):
+        def finalize_callback(X, tlm_map_id):
            for x_id in sorted(tuple(X)):
                x = X.get(x_id, None)
                if x is not None:
                    getattr(x, "_tlm_adjoint__tangent_linears", {}).pop(tlm_map_id, None)  # noqa: E501

-        weakref.finalize(self, weakref_finalize,
+        weakref.finalize(self, finalize_callback,
                         self._X, self._id)

        if len(M) == 1: