Merge pull request #579 from tlm-adjoint/jrmaddison/torch_example
Torch interface updates
jrmaddison authored Jun 13, 2024
2 parents 6c27141 + 7f98741 commit 8cbcaf3
Showing 4 changed files with 71 additions and 44 deletions.
5 changes: 3 additions & 2 deletions tests/base/test_torch.py
@@ -49,7 +49,7 @@ def test_torch_wrapped(setup_test, # noqa: F811
     def forward(m):
         return Float(m)
 
-    _, _, x_t = torch_wrapped(forward, m)
+    x_t = torch_wrapped(forward, m.space)(*to_torch_tensors(m))
     from_torch_tensors(x, x_t)
 
     assert x is not m
@@ -74,7 +74,8 @@ def forward(m):
         return m ** 4
 
     J_ref = complex(m) ** 4
-    _, forward_t, J_t = torch_wrapped(forward, m)
+    forward_t = torch_wrapped(forward, m.space)
+    J_t = forward_t(*to_torch_tensors(m))
     from_torch_tensors(J, J_t)
     assert abs(complex(J) - complex(J_ref)) < 1.0e-15
5 changes: 3 additions & 2 deletions tests/firedrake/test_torch.py
@@ -51,7 +51,7 @@ def test_torch_wrapped(setup_test, test_leaks):
     def forward(m):
         return m.copy(deepcopy=True)
 
-    _, _, x_t = torch_wrapped(forward, m)
+    x_t = torch_wrapped(forward, m.function_space())(*to_torch_tensors(m))
     from_torch_tensors(x, x_t)
 
     err = var_copy(x)
@@ -79,7 +79,8 @@ def forward(m):
         return J
 
     J_ref = assemble((m ** 4) * dx)
-    _, forward_t, J_t = torch_wrapped(forward, m)
+    forward_t = torch_wrapped(forward, m.function_space())
+    J_t = forward_t(*to_torch_tensors(m))
     from_torch_tensors(J, J_t)
     assert abs(complex(J) - complex(J_ref)) == 0.0
2 changes: 1 addition & 1 deletion tlm_adjoint/firedrake/__init__.py
@@ -11,7 +11,7 @@
 from .backend_interface import *
 from .assembly import *
 from .assignment import *
-from .block_system import ConstantNullspace, DirichletBCNullspace, UnityNullspace
+from .block_system import ConstantNullspace, DirichletBCNullspace, UnityNullspace, WhiteNoiseSampler
 from .caches import *
 from .expr import *
 from .interpolation import *
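For context, the added re-export makes WhiteNoiseSampler importable from the
top-level Firedrake interface rather than only from the block_system
submodule where it is defined, e.g. (not part of the commit):

    from tlm_adjoint.firedrake import WhiteNoiseSampler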
103 changes: 64 additions & 39 deletions tlm_adjoint/torch.py
@@ -7,10 +7,10 @@
 coupling PyTorch and Firedrake', 2023, arXiv:2303.06871v3
 """
 
-from .caches import clear_caches
+from .caches import clear_caches as _clear_caches
 from .interface import (
-    packed, var_comm, var_dtype, var_get_values, var_id, var_locked, var_new,
-    var_new_conjugate_dual, var_set_values)
+    Packed, packed, space_new, var_comm, var_dtype, var_get_values, var_id,
+    var_locked, var_new_conjugate_dual, var_set_values)
 from .manager import (
     compute_gradient, manager as _manager, reset_manager, restore_manager,
     set_manager, start_manager, stop_manager)
@@ -37,10 +37,19 @@ def to_torch_tensor(x, *args, **kwargs):
 def to_torch_tensors(X, *args, **kwargs):
     """Convert one or more variables to :class:`torch.Tensor` objects.
 
-    :arg X: A variable or :class:`Sequence` or variables.
-    :returns: A :class:`Sequence` of :class:`torch.Tensor` objects.
+    Parameters
+    ----------
 
-    Remaining arguments are passed to :func:`torch.tensor`.
+    X : variable or Sequence[variable]
+        Variables to be converted.
+    args, kwargs
+        Passed to :func:`torch.tensor`.
+
+    Returns
+    -------
+
+    tuple[:class:`torch.Tensor`]
+        The converted tensors.
     """
 
     return tuple(to_torch_tensor(x, *args, **kwargs) for x in packed(X))
@@ -54,8 +54,13 @@ def from_torch_tensor(x, x_t):
 def from_torch_tensors(X, X_t):
     """Copy data from PyTorch tensors into variables.
 
-    :arg X: A variable or :class:`Sequence` or variables.
-    :arg X_t: A :class:`Sequence` of :class:`torch.Tensor` objects.
+    Parameters
+    ----------
+
+    X : variable or Sequence[variable]
+        Output.
+    X_t : Sequence[:class:`torch.Tensor`]
+        Input.
     """
 
     X = packed(X)
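For context, a minimal round-trip sketch of the two converters (not part of
the commit), assuming tlm_adjoint's base Float type:

    from tlm_adjoint import Float
    from tlm_adjoint.torch import from_torch_tensors, to_torch_tensors

    x = Float(3.0)
    x_t = to_torch_tensors(x)   # tuple of torch.Tensor copies of x
    y = Float()
    from_torch_tensors(y, x_t)  # copy the tensor data back into y
    assert complex(y) == complex(x)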
@@ -66,78 +80,89 @@ def from_torch_tensors(X, X_t):
 
 
 @restore_manager
-def _forward(forward, M, manager):
+def _forward(forward, M, manager, *, clear_caches=False):
     set_manager(manager)
     reset_manager()
-    clear_caches()
+    if clear_caches:
+        _clear_caches()
 
     start_manager()
     with var_locked(*M):
-        X = packed(forward(*M))
+        X = forward(*M)
+    X_packed = Packed(X)
+    X = tuple(X_packed)
     J = Float(dtype=var_dtype(X[0]), comm=var_comm(X[0]))
     adj_X = tuple(map(var_new_conjugate_dual, X))
     AdjointActionMarker(J, X, adj_X).solve()
     stop_manager()
 
-    return X, J, adj_X
+    return X_packed.unpack(X), J, X_packed.unpack(adj_X)
 
 
 class TorchInterface(object if torch is None else torch.autograd.Function):
     @staticmethod
-    def forward(ctx, forward, manager, J_id, M, *M_t):
-        M = tuple(map(var_new, M))
+    def forward(ctx, forward, manager, clear_caches, J_id, space, *M_t):
+        M = tuple(map(space_new, space))
         from_torch_tensors(M, M_t)
 
-        X, J, adj_X = _forward(forward, M, manager)
+        X, J, adj_X = _forward(forward, M, manager,
+                               clear_caches=clear_caches)
 
         J_id[0] = var_id(J)
-        ctx._tlm_adjoint__output_ctx = (forward, manager, J_id, M, J, adj_X)
+        ctx._tlm_adjoint__output_ctx = (forward, manager, clear_caches,
+                                        J_id, M, J, adj_X)
         return to_torch_tensors(X)
 
     @staticmethod
     @restore_manager
     def backward(ctx, *adj_X_t):
-        forward, manager, J_id, M, J, adj_X = ctx._tlm_adjoint__output_ctx
+        (forward, manager, clear_caches,
+         J_id, M, J, adj_X) = ctx._tlm_adjoint__output_ctx
         if var_id(J) != J_id[0] or manager._cp_schedule.is_exhausted:
-            _, J, adj_X = _forward(forward, M, manager)
+            _, J, adj_X = _forward(forward, M, manager,
+                                   clear_caches=clear_caches)
             J_id[0] = var_id(J)
 
         from_torch_tensors(adj_X, adj_X_t)
         set_manager(manager)
         dJ = compute_gradient(J, M)
 
-        return (None, None, None, None) + to_torch_tensors(dJ)
+        return (None, None, None, None, None) + to_torch_tensors(dJ)
 
 
-def torch_wrapped(forward, M, *, manager=None):
+def torch_wrapped(forward, space, *, manager=None, clear_caches=True):
     """Wrap a model, differentiated using tlm_adjoint, so that it can be used
     with PyTorch.
 
-    :arg forward: A callable which accepts one or more variable arguments, and
-        returns a variable or :class:`Sequence` of variables.
-    :arg M: A variable or :class:`Sequence` of variables defining the input to
-        `forward`.
-    :arg manager: An :class:`.EquationManager` used to create an internal
-        manager via :meth:`.EquationManager.new`. `manager()` is used if not
-        supplied.
-    :returns: A :class:`tuple` `(M_t, forward_t, X_t)`, where
-
-        - `M_t` is a :class:`torch.Tensor` storing the value of `M`.
-        - `forward_t` is a version of `forward` with :class:`torch.Tensor`
-          inputs and outputs.
-        - `X_t` is a :class:`torch.Tensor` containing the value of
-          `forward` evaluated with `M` as input.
+    Parameters
+    ----------
+
+    forward : callable
+        Accepts one or more variable arguments, and returns a variable or
+        :class:`Sequence` of variables.
+    space : space or Sequence[space]
+        Defines the spaces for input arguments.
+    manager : :class:`.EquationManager`
+        Used to create an internal manager via :meth:`.EquationManager.new`.
+        `manager()` is used if not supplied.
+    clear_caches : bool
+        Whether to clear caches before a call of `forward`.
+
+    Returns
+    -------
+
+    callable
+        A version of `forward` with :class:`torch.Tensor` inputs and outputs.
     """
 
-    M = packed(M)
+    space = packed(space)
     if manager is None:
        manager = _manager()
     manager = manager.new()
+    J_id = [None]
 
-    M_t = to_torch_tensors(M, requires_grad=True)
-    J_id = [None]
-
     def forward_t(*M_t):
-        return TorchInterface.apply(forward, manager, J_id, M, *M_t)
+        return TorchInterface.apply(
+            forward, manager, clear_caches, J_id, space, *M_t)
 
-    return M_t, forward_t, forward_t(*M_t)
+    return forward_t
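
For context, a minimal usage sketch of the updated interface (not part of the
commit), patterned on the updated tests/base/test_torch.py; it assumes
tlm_adjoint's base Float type and standard torch autograd usage:

    from tlm_adjoint import Float
    from tlm_adjoint.torch import (
        from_torch_tensors, to_torch_tensors, torch_wrapped)

    def forward(m):
        return m ** 4

    m = Float(2.0)
    # torch_wrapped now takes the input space, not the variable, and
    # returns only the torch-callable version of forward.
    forward_t = torch_wrapped(forward, m.space)
    m_t = to_torch_tensors(m, requires_grad=True)  # inputs as torch tensors
    J_t = forward_t(*m_t)                          # tuple of torch.Tensor
    J = Float()
    from_torch_tensors(J, J_t)                     # J holds 2.0 ** 4 == 16.0

    # Gradients flow back through tlm_adjoint's adjoint computation.
    J_t[0].sum().backward()
    assert abs(float(m_t[0].grad) - 32.0) < 1.0e-12  # dJ/dm = 4 * m ** 3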
