Skip to content

Commit

Permalink
Merge pull request #567 from tlm-adjoint/jrmaddison/enlist
Browse files Browse the repository at this point in the history
Add `Packed` class, based on pyadjoint `Enlist` functionality
  • Loading branch information
jrmaddison authored May 28, 2024
2 parents 29b0ad6 + a49bf1e commit 94afb49
Show file tree
Hide file tree
Showing 25 changed files with 347 additions and 443 deletions.
42 changes: 21 additions & 21 deletions tlm_adjoint/block_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@
"""

from .interface import (
comm_dup_cached, space_comm, space_default_space_type, space_eq, space_new,
var_assign, var_locked, var_zero)
Packed, comm_dup_cached, packed, space_comm, space_default_space_type,
space_eq, space_new, var_assign, var_locked, var_zero)
from .manager import manager_disabled
from .petsc import (
PETScOptions, PETScVec, PETScVecInterface, attach_destroy_finalizer)
Expand Down Expand Up @@ -113,20 +113,23 @@ def expand(e):

q = deque(map(expand, iterable))
while len(q) > 0:
e = q.popleft()
if isinstance(e, Sequence) and not isinstance(e, str):
q.extendleft(map(expand, reversed(e)))
else:
e = Packed(q.popleft())
if e.is_packed:
e, = e
yield e
else:
q.extendleft(map(expand, reversed(e)))


def tuple_sub(iterable, sequence):
iterator = iter_sub(iterable)

def tuple_sub(iterator, value):
if isinstance(value, Sequence) and not isinstance(value, str):
value = Packed(value)
if value.is_packed:
return next(iterator)
else:
return tuple(tuple_sub(iterator, e) for e in value)
return next(iterator)

t = tuple_sub(iterator, sequence)

Expand Down Expand Up @@ -252,10 +255,8 @@ class MixedSpace(PETScVecInterface, Sequence):
def __init__(self, spaces):
if isinstance(spaces, MixedSpace):
spaces = spaces.split_space
elif isinstance(spaces, Sequence):
spaces = tuple(spaces)
else:
spaces = (spaces,)
spaces = packed(spaces)
flattened_spaces = tuple(space if isinstance(space, TypedSpace)
else TypedSpace(space)
for space in iter_sub(spaces))
Expand Down Expand Up @@ -489,10 +490,7 @@ class BlockNullspace(Nullspace, Sequence):
"""

def __init__(self, nullspaces):
if not isinstance(nullspaces, Sequence):
nullspaces = (nullspaces,)

nullspaces = list(nullspaces)
nullspaces = list(packed(nullspaces))
for i, nullspace in enumerate(nullspaces):
if nullspace is None:
nullspaces[i] = NoneNullspace()
Expand All @@ -502,8 +500,7 @@ def __init__(self, nullspaces):
self._nullspaces = nullspaces

def __new__(cls, nullspaces, *args, **kwargs):
if not isinstance(nullspaces, Sequence):
nullspaces = (nullspaces,)
nullspaces = packed(nullspaces)
for nullspace in nullspaces:
if nullspace is not None \
and not isinstance(nullspace, NoneNullspace):
Expand Down Expand Up @@ -862,17 +859,20 @@ def solve(self, u, b, *,
the solution.
"""

u_packed = Packed(u)
b_packed = Packed(b)
u = tuple(u_packed)
b = tuple(b_packed)

pc_fn = self._pc_fn
if not isinstance(u, Sequence):
u = (u,)
if u_packed.is_packed:
pc_fn_u = pc_fn

def pc_fn(u, b):
u, = tuple(iter_sub(u))
pc_fn_u(u, b)

if not isinstance(b, Sequence):
b = (b,)
if b_packed.is_packed:
pc_fn_b = pc_fn

def pc_fn(u, b):
Expand Down
29 changes: 13 additions & 16 deletions tlm_adjoint/cached_hessian.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .interface import (
VariableStateLockDictionary, is_var, var_check_state_lock, var_id,
Packed, VariableStateLockDictionary, packed, var_check_state_lock, var_id,
var_increment_state_lock, var_new, var_scalar_value)

from .caches import clear_caches
Expand Down Expand Up @@ -152,11 +152,10 @@ def __init__(self, J, *, manager=None, cache_adjoint=True):

@restore_manager
def compute_gradient(self, M, M0=None):
if is_var(M):
J_val, (dJ,) = self.compute_gradient(
(M,),
M0=None if M0 is None else (M0,))
return J_val, dJ
M_packed = Packed(M)
M = tuple(M_packed)
if M0 is not None:
M0 = packed(M0)

var_check_state_lock(self._J)

Expand All @@ -173,15 +172,15 @@ def compute_gradient(self, M, M0=None):
store_adjoint=self._cache_adjoint)

reset_manager()
return J_val, dJ
return J_val, M_packed.unpack(dJ)

@restore_manager
def action(self, M, dM, M0=None):
if is_var(M):
J_val, dJ_val, (ddJ,) = self.action(
(M,), (dM,),
M0=None if M0 is None else (M0,))
return J_val, dJ_val, ddJ
M_packed = Packed(M)
M = tuple(M_packed)
dM = packed(dM)
if M0 is not None:
M0 = packed(M0)

var_check_state_lock(self._J)

Expand All @@ -198,7 +197,7 @@ def action(self, M, dM, M0=None):
store_adjoint=self._cache_adjoint)

reset_manager()
return J_val, dJ_val, ddJ
return J_val, dJ_val, M_packed.unpack(ddJ)


class CachedGaussNewton(GaussNewton, HessianOptimization):
Expand All @@ -215,9 +214,7 @@ class CachedGaussNewton(GaussNewton, HessianOptimization):

def __init__(self, X, R_inv_action, B_inv_action=None, *,
manager=None):
if is_var(X):
X = (X,)

X = packed(X)
for x in X:
var_increment_state_lock(x, self)

Expand Down
77 changes: 32 additions & 45 deletions tlm_adjoint/equation.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from .interface import (
check_space_types, is_var, var_id, var_is_alias, var_is_static, var_locked,
var_new, var_replacement, var_update_caches, var_update_state, var_zero)
Packed, check_space_types, is_var, packed, var_id, var_is_alias,
var_is_static, var_locked, var_new, var_replacement, var_update_caches,
var_update_state, var_zero)

from .alias import WeakAlias, gc_disabled
from .manager import manager as _manager
from .manager import annotation_enabled, paused_manager, tlm_enabled

from collections.abc import Sequence
import functools
import inspect
import itertools
Expand Down Expand Up @@ -128,8 +128,8 @@ def __init__(self, X, deps, nl_deps=None, *,
ic_deps=None, ic=None,
adj_ic_deps=None, adj_ic=None,
adj_type="conjugate_dual"):
if is_var(X):
X = (X,)
X_packed = Packed(X)
X = tuple(X_packed)
X_ids = set(map(var_id, X))
dep_ids = {var_id(dep): i for i, dep in enumerate(deps)}
for x in X:
Expand Down Expand Up @@ -193,16 +193,14 @@ def __init__(self, X, deps, nl_deps=None, *,

if adj_type in ["primal", "conjugate_dual"]:
adj_type = tuple(adj_type for _ in X)
elif isinstance(adj_type, Sequence):
if len(adj_type) != len(X):
raise ValueError("Invalid adjoint type")
else:
if len(adj_type) != len(X):
raise ValueError("Invalid adjoint type")
for adj_x_type in adj_type:
if adj_x_type not in {"primal", "conjugate_dual"}:
raise ValueError("Invalid adjoint type")

super().__init__()
self._packed = X_packed.mapped(lambda x: None)
self._X = tuple(X)
self._deps = tuple(deps)
self._nl_deps = tuple(nl_deps)
Expand Down Expand Up @@ -254,6 +252,9 @@ def x(self):
x, = self._X
return x

def _unpack(self, obj):
return self._packed.unpack(obj)

def X(self, m=None):
"""Return forward solution variables.
Expand Down Expand Up @@ -397,7 +398,7 @@ def forward(self, X, deps=None):
with var_locked(*(dep for dep in (eq_deps if deps is None else deps)
if var_id(dep) not in X_ids)):
var_update_caches(*eq_deps, value=deps)
self.forward_solve(X[0] if len(X) == 1 else X, deps=deps)
self.forward_solve(self._unpack(X), deps=deps)
var_update_state(*X)
var_update_caches(*self.X(), value=X)

Expand All @@ -407,11 +408,9 @@ def forward_solve(self, X, deps=None):
Can assume that the currently active :class:`.EquationManager` is
paused.
:arg X: A variable if the forward solution has a single component,
otherwise a :class:`Sequence` of variables. May define an initial
guess, and should be set by this method. Subclasses may replace
this argument with `x` if the forward solution has a single
component.
:arg X: A variable or a :class:`Sequence` of variables storing the
solution. May define an initial guess, and should be set by this
method.
:arg deps: A :class:`tuple` of variables, defining values for
dependencies. Only the elements corresponding to `X` may be
modified. `self.dependencies()` should be used if not supplied.
Expand All @@ -428,8 +427,8 @@ def adjoint(self, adj_X, nl_deps, B, dep_Bs):
returned.
:arg nl_deps: A :class:`Sequence` of variables defining values for
non-linear dependencies. Should not be modified.
:arg B: A sequence of variables defining the right-hand-side of the
adjoint equation. May be modified or returned.
:arg B: A :class:`Sequence` of variables defining the right-hand-side
of the adjoint equation. May be modified or returned.
:arg dep_Bs: A :class:`Mapping` whose items are `(dep_index, dep_B)`.
Each `dep_B` is an :class:`.AdjointRHS` which should be updated by
subtracting adjoint derivative information computed by
Expand All @@ -443,20 +442,19 @@ def adjoint(self, adj_X, nl_deps, B, dep_Bs):
var_update_caches(*self.nonlinear_dependencies(), value=nl_deps)

adj_X = self.adjoint_jacobian_solve(
adj_X if adj_X is None or len(adj_X) != 1 else adj_X[0],
nl_deps, B[0] if len(B) == 1 else B)
None if adj_X is None else self._unpack(adj_X),
nl_deps, self._unpack(B))
if adj_X is None:
return None
elif is_var(adj_X):
adj_X = (adj_X,)
adj_X = packed(adj_X)
var_update_state(*adj_X)

for m, adj_x in enumerate(adj_X):
check_space_types(adj_x, self.X(m),
rel_space_type=self.adj_X_type(m))

self.subtract_adjoint_derivative_actions(
adj_X[0] if len(adj_X) == 1 else adj_X, nl_deps, dep_Bs)
self._unpack(adj_X), nl_deps, dep_Bs)

return tuple(adj_X)

Expand All @@ -477,7 +475,7 @@ def adjoint_cached(self, adj_X, nl_deps, dep_Bs):
var_update_caches(*self.nonlinear_dependencies(), value=nl_deps)

self.subtract_adjoint_derivative_actions(
adj_X[0] if len(adj_X) == 1 else adj_X, nl_deps, dep_Bs)
self._unpack(adj_X), nl_deps, dep_Bs)

def adjoint_derivative_action(self, nl_deps, dep_index, adj_X):
"""Return the action of the adjoint of a derivative of the forward
Expand All @@ -489,10 +487,8 @@ def adjoint_derivative_action(self, nl_deps, dep_index, adj_X):
:arg dep_index: An :class:`int`. The derivative is defined by
differentiation of the forward residual with respect to
`self.dependencies()[dep_index]`.
:arg adj_X: The adjoint solution. A variable if the adjoint solution
has a single component, otherwise a :class:`Sequence` of variables.
Should not be modified. Subclasses may replace this argument with
`adj_x` if the adjoint solution has a single component.
:arg adj_X: The adjoint solution. A variable or a :class:`Sequence` of
variables. Should not be modified.
:returns: The action of the adjoint of a derivative on the adjoint
solution. Will be passed to
:func:`.subtract_adjoint_derivative_action`, and valid types depend
Expand All @@ -510,10 +506,8 @@ def subtract_adjoint_derivative_actions(self, adj_X, nl_deps, dep_Bs):
Can be overridden for an optimized implementation, but otherwise uses
:meth:`.Equation.adjoint_derivative_action`.
:arg adj_X: The adjoint solution. A variable if the adjoint solution
has a single component, otherwise a :class:`Sequence` of variables.
Should not be modified. Subclasses may replace this argument with
`adj_x` if the adjoint solution has a single component.
:arg adj_X: The adjoint solution. A variable or a :class:`Sequence` of
variables. Should not be modified.
:arg nl_deps: A :class:`Sequence` of variables defining values for
non-linear dependencies. Should not be modified.
:arg dep_Bs: A :class:`Mapping` whose items are `(dep_index, dep_B)`.
Expand All @@ -529,18 +523,14 @@ def subtract_adjoint_derivative_actions(self, adj_X, nl_deps, dep_Bs):
def adjoint_jacobian_solve(self, adj_X, nl_deps, B):
"""Compute an adjoint solution.
:arg adj_X: Either `None`, or a variable (if the adjoint solution has a
single component) or :class:`Sequence` of variables (otherwise)
defining the initial guess for an iterative solve. May be modified
or returned. Subclasses may replace this argument with `adj_x` if
the adjoint solution has a single component.
:arg adj_X: Either `None`, or a variable or :class:`Sequence` of
variables defining the initial guess for an iterative solve. May be
modified or returned.
:arg nl_deps: A :class:`Sequence` of variables defining values for
non-linear dependencies. Should not be modified.
:arg B: The right-hand-side. A variable (if the adjoint solution has a
single component) or :class:`Sequence` of variables (otherwise)
storing the value of the right-hand-side. May be modified or
returned. Subclasses may replace this argument with `b` if the
adjoint solution has a single component.
:arg B: The right-hand-side. A variable or :class:`Sequence` of
variables storing the value of the right-hand-side. May be modified
or returned.
:returns: A variable or :class:`Sequence` of variables storing the
value of the adjoint solution. May return `None` to indicate a
value of zero.
Expand Down Expand Up @@ -579,13 +569,10 @@ class ZeroAssignment(Equation):
"""

def __init__(self, X):
if is_var(X):
X = (X,)
X = packed(X)
super().__init__(X, X, nl_deps=[], ic=False, adj_ic=False)

def forward_solve(self, X, deps=None):
if is_var(X):
X = (X,)
for x in X:
var_zero(x)

Expand Down
Loading

0 comments on commit 94afb49

Please sign in to comment.