From f23933d33ba49290400ead30dc12708763e0f1ae Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Wed, 5 Jun 2024 09:57:15 +0100 Subject: [PATCH 1/6] numpydoc --- tlm_adjoint/torch.py | 59 ++++++++++++++++++++++++++++++-------------- 1 file changed, 40 insertions(+), 19 deletions(-) diff --git a/tlm_adjoint/torch.py b/tlm_adjoint/torch.py index d8f6ed24..88f1e32f 100644 --- a/tlm_adjoint/torch.py +++ b/tlm_adjoint/torch.py @@ -37,10 +37,19 @@ def to_torch_tensor(x, *args, **kwargs): def to_torch_tensors(X, *args, **kwargs): """Convert one or more variables to :class:`torch.Tensor` objects. - :arg X: A variable or :class:`Sequence` or variables. - :returns: A :class:`Sequence` of :class:`torch.Tensor` objects. + Parameters + ---------- - Remaining arguments are passed to :func:`torch.tensor`. + X : variable or Sequence[variable] + Variables to be converted + args, kwargs + Passed to :func:`torch.tensor`. + + Returns + ------- + + tuple[variable] + The converted variables. """ return tuple(to_torch_tensor(x, *args, **kwargs) for x in packed(X)) @@ -54,8 +63,13 @@ def from_torch_tensor(x, x_t): def from_torch_tensors(X, X_t): """Copy data from PyTorch tensors into variables. - :arg X: A variable or :class:`Sequence` or variables. - :arg X_t: A :class:`Sequence` of :class:`torch.Tensor` objects. + Parameters + ---------- + + X : variable or Sequence[variable] + Output. + X_t : Sequence[:class:`torch.Tensor`] + Input. """ X = packed(X) @@ -113,20 +127,27 @@ def torch_wrapped(forward, M, *, manager=None): """Wrap a model, differentiated using tlm_adjoint, so that it can be used with PyTorch. - :arg forward: A callable which accepts one or more variable arguments, and - returns a variable or :class:`Sequence` of variables. - :arg M: A variable or :class:`Sequence` of variables defining the input to - `forward`. - :arg manager: An :class:`.EquationManager` used to create an internal - manager via :meth:`.EquationManager.new`. `manager()` is used if not - supplied. - :returns: A :class:`tuple` `(M_t, forward_t, X_t)`, where - - - `M_t` is a :class:`torch.Tensor` storing the value of `M`. - - `forward_t` is a version of `forward` with :class:`torch.Tensor` - inputs and outputs. - - `X_t` is a :class:`torch.Tensor` containing the value of - `forward` evaluated with `M` as input. + Parameters + ---------- + + forward : callable + Accepts one or more variable arguments, and returns a variable or + :class:`Sequence` of variables. + M : variable or Sequence[variable] + Defines the input to `forward`. + manager : :class:`.EquationManager` + Used to create an internal manager via :meth:`.EquationManager.new`. + `manager()` is used if not supplied. + + Returns + ------- + + (M_t, forward_t, X_t) : tuple[variable or Sequence[variable], callable, \ + tuple[:class:`torch.Tensor`]] + `M_t` stores the value of `M`. `forward_t` is a version of `forward` + with :class:`torch.Tensor` inputs and outputs. `X_t` is a + :class:`torch.Tensor` containing the value of `forward` evaluated with + `M` as input. """ M = packed(M) From 49c273812e658f094e9b47dbca3b7ca65971d8a0 Mon Sep 17 00:00:00 2001 From: "James R. 
Maddison" Date: Wed, 5 Jun 2024 18:10:15 +0100 Subject: [PATCH 2/6] Torch interface updates --- tests/base/test_torch.py | 5 +++-- tests/firedrake/test_torch.py | 5 +++-- tlm_adjoint/torch.py | 38 ++++++++++++++++------------------- 3 files changed, 23 insertions(+), 25 deletions(-) diff --git a/tests/base/test_torch.py b/tests/base/test_torch.py index 202c61fe..aceb1372 100644 --- a/tests/base/test_torch.py +++ b/tests/base/test_torch.py @@ -49,7 +49,7 @@ def test_torch_wrapped(setup_test, # noqa: F811 def forward(m): return Float(m) - _, _, x_t = torch_wrapped(forward, m) + x_t = torch_wrapped(forward, m.space)(*to_torch_tensors(m)) from_torch_tensors(x, x_t) assert x is not m @@ -74,7 +74,8 @@ def forward(m): return m ** 4 J_ref = complex(m) ** 4 - _, forward_t, J_t = torch_wrapped(forward, m) + forward_t = torch_wrapped(forward, m.space) + J_t = forward_t(*to_torch_tensors(m)) from_torch_tensors(J, J_t) assert abs(complex(J) - complex(J_ref)) < 1.0e-15 diff --git a/tests/firedrake/test_torch.py b/tests/firedrake/test_torch.py index 643f9bd2..e1c8dc6c 100644 --- a/tests/firedrake/test_torch.py +++ b/tests/firedrake/test_torch.py @@ -51,7 +51,7 @@ def test_torch_wrapped(setup_test, test_leaks): def forward(m): return m.copy(deepcopy=True) - _, _, x_t = torch_wrapped(forward, m) + x_t = torch_wrapped(forward, m.function_space())(*to_torch_tensors(m)) from_torch_tensors(x, x_t) err = var_copy(x) @@ -79,7 +79,8 @@ def forward(m): return J J_ref = assemble((m ** 4) * dx) - _, forward_t, J_t = torch_wrapped(forward, m) + forward_t = torch_wrapped(forward, m.function_space()) + J_t = forward_t(*to_torch_tensors(m)) from_torch_tensors(J, J_t) assert abs(complex(J) - complex(J_ref)) == 0.0 diff --git a/tlm_adjoint/torch.py b/tlm_adjoint/torch.py index 88f1e32f..f025b3af 100644 --- a/tlm_adjoint/torch.py +++ b/tlm_adjoint/torch.py @@ -9,8 +9,8 @@ from .caches import clear_caches from .interface import ( - packed, var_comm, var_dtype, var_get_values, var_id, var_locked, var_new, - var_new_conjugate_dual, var_set_values) + Packed, packed, space_new, var_comm, var_dtype, var_get_values, var_id, + var_locked, var_new_conjugate_dual, var_set_values) from .manager import ( compute_gradient, manager as _manager, reset_manager, restore_manager, set_manager, start_manager, stop_manager) @@ -87,19 +87,21 @@ def _forward(forward, M, manager): start_manager() with var_locked(*M): - X = packed(forward(*M)) + X = forward(*M) + X_packed = Packed(X) + X = tuple(X_packed) J = Float(dtype=var_dtype(X[0]), comm=var_comm(X[0])) adj_X = tuple(map(var_new_conjugate_dual, X)) AdjointActionMarker(J, X, adj_X).solve() stop_manager() - return X, J, adj_X + return X_packed.unpack(X), J, X_packed.unpack(adj_X) class TorchInterface(object if torch is None else torch.autograd.Function): @staticmethod - def forward(ctx, forward, manager, J_id, M, *M_t): - M = tuple(map(var_new, M)) + def forward(ctx, forward, manager, J_id, space, *M_t): + M = tuple(map(space_new, space)) from_torch_tensors(M, M_t) X, J, adj_X = _forward(forward, M, manager) @@ -123,7 +125,7 @@ def backward(ctx, *adj_X_t): return (None, None, None, None) + to_torch_tensors(dJ) -def torch_wrapped(forward, M, *, manager=None): +def torch_wrapped(forward, space, *, manager=None): """Wrap a model, differentiated using tlm_adjoint, so that it can be used with PyTorch. @@ -133,32 +135,26 @@ def torch_wrapped(forward, M, *, manager=None): forward : callable Accepts one or more variable arguments, and returns a variable or :class:`Sequence` of variables. 
- M : variable or Sequence[variable] - Defines the input to `forward`. - manager : :class:`.EquationManager` Used to create an internal manager via :meth:`.EquationManager.new`. `manager()` is used if not supplied. + space : space or Sequence[space] + Defines the spaces for input arguments. Returns ------- - (M_t, forward_t, X_t) : tuple[variable or Sequence[variable], callable, \ - tuple[:class:`torch.Tensor`]] - `M_t` stores the value of `M`. `forward_t` is a version of `forward` - with :class:`torch.Tensor` inputs and outputs. `X_t` is a - :class:`torch.Tensor` containing the value of `forward` evaluated with - `M` as input. + callable + A version of `forward` with :class:`torch.Tensor` inputs and outputs. """ - M = packed(M) + space = packed(space) if manager is None: manager = _manager() manager = manager.new() - J_id = [None] - M_t = to_torch_tensors(M, requires_grad=True) + J_id = [None] def forward_t(*M_t): - return TorchInterface.apply(forward, manager, J_id, M, *M_t) + return TorchInterface.apply(forward, manager, J_id, space, *M_t) - return M_t, forward_t, forward_t(*M_t) + return forward_t From 82eab10870c57db1c9c46f726cdb841d97db91e7 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Thu, 6 Jun 2024 11:42:39 +0100 Subject: [PATCH 3/6] Add option to disable use of clear_caches --- tlm_adjoint/torch.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/tlm_adjoint/torch.py b/tlm_adjoint/torch.py index f025b3af..2a5087f8 100644 --- a/tlm_adjoint/torch.py +++ b/tlm_adjoint/torch.py @@ -80,10 +80,11 @@ def from_torch_tensors(X, X_t): @restore_manager -def _forward(forward, M, manager): +def _forward(forward, M, manager, *, clear_caches=False): set_manager(manager) reset_manager() - clear_caches() + if clear_caches: + globals()["clear_caches"]() start_manager() with var_locked(*M): @@ -100,32 +101,36 @@ def _forward(forward, M, manager): class TorchInterface(object if torch is None else torch.autograd.Function): @staticmethod - def forward(ctx, forward, manager, J_id, space, *M_t): + def forward(ctx, forward, manager, clear_caches, J_id, space, *M_t): M = tuple(map(space_new, space)) from_torch_tensors(M, M_t) - X, J, adj_X = _forward(forward, M, manager) + X, J, adj_X = _forward(forward, M, manager, + clear_caches=clear_caches) J_id[0] = var_id(J) - ctx._tlm_adjoint__output_ctx = (forward, manager, J_id, M, J, adj_X) + ctx._tlm_adjoint__output_ctx = (forward, manager, clear_caches, + J_id, M, J, adj_X) return to_torch_tensors(X) @staticmethod @restore_manager def backward(ctx, *adj_X_t): - forward, manager, J_id, M, J, adj_X = ctx._tlm_adjoint__output_ctx + (forward, manager, clear_caches, + J_id, M, J, adj_X) = ctx._tlm_adjoint__output_ctx if var_id(J) != J_id[0] or manager._cp_schedule.is_exhausted: - _, J, adj_X = _forward(forward, M, manager) + _, J, adj_X = _forward(forward, M, manager, + clear_caches=clear_caches) J_id[0] = var_id(J) from_torch_tensors(adj_X, adj_X_t) set_manager(manager) dJ = compute_gradient(J, M) - return (None, None, None, None) + to_torch_tensors(dJ) + return (None, None, None, None, None) + to_torch_tensors(dJ) -def torch_wrapped(forward, space, *, manager=None): +def torch_wrapped(forward, space, *, manager=None, clear_caches=True): """Wrap a model, differentiated using tlm_adjoint, so that it can be used with PyTorch. @@ -139,6 +144,7 @@ def torch_wrapped(forward, space, *, manager=None): `manager()` is used if not supplied. space : space or Sequence[space] Defines the spaces for input arguments. 
+ clear_caches : Whether to clear caches before a call of `forward`. Returns ------- @@ -155,6 +161,7 @@ def torch_wrapped(forward, space, *, manager=None): J_id = [None] def forward_t(*M_t): - return TorchInterface.apply(forward, manager, J_id, space, *M_t) + return TorchInterface.apply( + forward, manager, clear_caches, J_id, space, *M_t) return forward_t From a4ffaab54ab55dace314646d8d07c61b8169d682 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Thu, 6 Jun 2024 11:43:04 +0100 Subject: [PATCH 4/6] Add WhiteNoiseSampler to tlm_adjoint.firedrake --- tlm_adjoint/firedrake/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tlm_adjoint/firedrake/__init__.py b/tlm_adjoint/firedrake/__init__.py index 631f3df9..23a53adc 100644 --- a/tlm_adjoint/firedrake/__init__.py +++ b/tlm_adjoint/firedrake/__init__.py @@ -11,7 +11,7 @@ from .backend_interface import * from .assembly import * from .assignment import * -from .block_system import ConstantNullspace, DirichletBCNullspace, UnityNullspace +from .block_system import ConstantNullspace, DirichletBCNullspace, UnityNullspace, WhiteNoiseSampler from .caches import * from .expr import * from .interpolation import * From ae7a1fb051e732b7b79399cb4cea5dbd86648b3c Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Thu, 6 Jun 2024 13:37:28 +0100 Subject: [PATCH 5/6] flake8 --- tlm_adjoint/torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tlm_adjoint/torch.py b/tlm_adjoint/torch.py index 2a5087f8..ca4a2335 100644 --- a/tlm_adjoint/torch.py +++ b/tlm_adjoint/torch.py @@ -7,7 +7,7 @@ coupling PyTorch and Firedrake', 2023, arXiv:2303.06871v3 """ -from .caches import clear_caches +from .caches import clear_caches as _clear_caches from .interface import ( Packed, packed, space_new, var_comm, var_dtype, var_get_values, var_id, var_locked, var_new_conjugate_dual, var_set_values) @@ -84,7 +84,7 @@ def _forward(forward, M, manager, *, clear_caches=False): set_manager(manager) reset_manager() if clear_caches: - globals()["clear_caches"]() + _clear_caches() start_manager() with var_locked(*M): From 7f9874198b2bccedad6f8eb7b0202fce35768351 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Thu, 13 Jun 2024 10:01:40 +0100 Subject: [PATCH 6/6] Documentation fix --- tlm_adjoint/torch.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tlm_adjoint/torch.py b/tlm_adjoint/torch.py index ca4a2335..1b4c8db0 100644 --- a/tlm_adjoint/torch.py +++ b/tlm_adjoint/torch.py @@ -140,10 +140,11 @@ def torch_wrapped(forward, space, *, manager=None, clear_caches=True): forward : callable Accepts one or more variable arguments, and returns a variable or :class:`Sequence` of variables. - Used to create an internal manager via :meth:`.EquationManager.new`. - `manager()` is used if not supplied. space : space or Sequence[space] Defines the spaces for input arguments. + manager : :class:`.EquationManager` + Used to create an internal manager via :meth:`.EquationManager.new`. + `manager()` is used if not supplied. clear_caches : Whether to clear caches before a call of `forward`. Returns
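Usage with the updated interface (PATCH 2/6): `torch_wrapped` now takes the space (or spaces) of the inputs rather than the input variables themselves, and returns only the wrapped callable, with inputs and outputs exchanged explicitly via `to_torch_tensors` and `from_torch_tensors`. A minimal sketch of the updated calling pattern, following the updated tests/base/test_torch.py; the construction of `m` and `J` with `Float` here is illustrative:

    from tlm_adjoint import Float
    from tlm_adjoint.torch import (
        from_torch_tensors, to_torch_tensors, torch_wrapped)

    def forward(m):
        # Any forward model recorded by tlm_adjoint; here a scalar expression
        return m ** 4

    m = Float(2.0)   # input variable (illustrative value)
    J = Float()      # variable receiving the output

    # Wrap using the *space* of the input, not the input variable itself
    forward_t = torch_wrapped(forward, m.space)

    # Evaluate with torch.Tensor inputs and outputs
    J_t = forward_t(*to_torch_tensors(m))
    from_torch_tensors(J, J_t)

With a Firedrake Function input the same pattern applies, passing `m.function_space()` as the space argument, as in tests/firedrake/test_torch.py.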
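Because `TorchInterface` subclasses `torch.autograd.Function`, derivatives of the wrapped forward flow through PyTorch autograd: the backward pass re-runs the forward under the internal manager when required and then calls `compute_gradient`. A sketch, assuming a real-valued scalar input so that `backward()` can be called directly on the output tensor, and showing the `clear_caches` keyword added in PATCH 3/6 (default `True`):

    # Disable clearing of tlm_adjoint caches before each forward evaluation
    forward_t = torch_wrapped(forward, m.space, clear_caches=False)

    # Extra arguments to to_torch_tensors are passed on to torch.tensor
    m_t = to_torch_tensors(m, requires_grad=True)
    J_t, = forward_t(*m_t)
    J_t.backward()        # adjoint computed via tlm_adjoint's compute_gradient
    dJ_dm = m_t[0].grad   # derivative of J with respect to m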