From 4d22b2ab42410efdd7a2209ba3c23c199be6fffe Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Sat, 13 Apr 2024 23:03:48 +0100 Subject: [PATCH] PyTorch interface --- .github/workflows/test-firedrake.yml | 2 + docs/source/conf.py | 1 + tests/base/test_torch.py | 88 +++++++++++++++ tests/fenics/test_base.py | 5 + tests/firedrake/test_base.py | 3 + tests/firedrake/test_torch.py | 92 +++++++++++++++ tlm_adjoint/hessian.py | 15 +-- tlm_adjoint/markers.py | 102 ++++++++++++++++- tlm_adjoint/torch.py | 163 +++++++++++++++++++++++++++ 9 files changed, 458 insertions(+), 13 deletions(-) create mode 100644 tests/base/test_torch.py create mode 100644 tests/firedrake/test_torch.py create mode 100644 tlm_adjoint/torch.py diff --git a/.github/workflows/test-firedrake.yml b/.github/workflows/test-firedrake.yml index 3c1d7102..469f4de8 100644 --- a/.github/workflows/test-firedrake.yml +++ b/.github/workflows/test-firedrake.yml @@ -31,6 +31,7 @@ jobs: run: | . /home/firedrake/firedrake/bin/activate python3 -m pip install jax[cpu] ruff pytest-timeout pytest-xdist + python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu - name: Lint run: | . /home/firedrake/firedrake/bin/activate @@ -60,6 +61,7 @@ jobs: run: | . /home/firedrake/firedrake/bin/activate python3 -m pip install jax[cpu] ruff pytest-timeout pytest-xdist + python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu - name: Lint run: | . /home/firedrake/firedrake/bin/activate diff --git a/docs/source/conf.py b/docs/source/conf.py index e9f2826c..64959e17 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -40,4 +40,5 @@ "python": ("https://docs.python.org/3", None), "scipy": ("https://docs.scipy.org/doc/scipy", None), "sympy": ("https://docs.sympy.org/latest", None), + "torch": ("https://pytorch.org/docs/stable", None), "ufl": ("https://fenics.readthedocs.io/projects/ufl/en/latest", None)} # noqa: E501 diff --git a/tests/base/test_torch.py b/tests/base/test_torch.py new file mode 100644 index 00000000..21ee5df7 --- /dev/null +++ b/tests/base/test_torch.py @@ -0,0 +1,88 @@ +from tlm_adjoint import DEFAULT_COMM, Float, set_default_float_dtype, var_id +from tlm_adjoint.torch import ( + from_torch_tensors, to_torch_tensors, torch_wrapped) + +from .test_base import setup_test # noqa: F401 + +import numpy as np +import pytest +try: + import torch +except ImportError: + torch = None + +pytestmark = pytest.mark.skipif( + DEFAULT_COMM.size not in {1, 4}, + reason="tests must be run in serial, or with 4 processes") +pytestmark = pytest.mark.skipif( + torch is None, + reason="PyTorch not available") + + +@pytest.mark.base +@pytest.mark.parametrize("dtype", [np.double, np.cdouble]) +def test_torch_tensor_roundtrip(setup_test, # noqa: F811 + dtype): + set_default_float_dtype(dtype) + + if issubclass(dtype, np.complexfloating): + x = Float(-np.sqrt(2.0) + 1.0j * np.sqrt(3.0)) + else: + x = Float(-np.sqrt(2.0)) + y = Float() + from_torch_tensors(y, to_torch_tensors(x)) + assert abs(complex(x) - complex(y)) == 0.0 + + +@pytest.mark.base +@pytest.mark.parametrize("dtype", [np.double, np.cdouble]) +def test_torch_wrapped(setup_test, # noqa: F811 + dtype): + set_default_float_dtype(dtype) + + if issubclass(dtype, np.complexfloating): + m = Float(-np.sqrt(2.0) + 1.0j * np.sqrt(3.0)) + else: + m = Float(-np.sqrt(2.0)) + x = Float() + + def forward(m): + return Float(m) + + _, _, x_t = torch_wrapped(forward, m) + from_torch_tensors(x, x_t) + + assert x is not m + assert var_id(x) != var_id(m) + assert 
abs(complex(x) - complex(m)) == 0.0 + + +@pytest.mark.base +@pytest.mark.parametrize("dtype", [np.double, np.cdouble]) +@pytest.mark.skipif(DEFAULT_COMM.size > 1, reason="serial only") +def test_torch_vjp(setup_test, # noqa: F811 + dtype): + set_default_float_dtype(dtype) + + if issubclass(dtype, np.complexfloating): + m = Float(-np.sqrt(2.0) + 1.0j * np.sqrt(3.0)) + else: + m = Float(-np.sqrt(2.0)) + J = Float(name="J") + + def forward(m): + return m ** 4 + + J_ref = complex(m) ** 4 + _, forward_t, J_t = torch_wrapped(forward, m) + from_torch_tensors(J, J_t) + assert abs(complex(J) - complex(J_ref)) < 1.0e-15 + + if issubclass(dtype, np.complexfloating): + dm = Float(1.0 - 1.0j) + else: + dm = Float(1.0) + dm_t = to_torch_tensors(dm, requires_grad=True) + + assert torch.autograd.gradcheck(forward_t, dm_t, eps=1.0e-8, + atol=1.0e-8, rtol=1.0e-7) diff --git a/tests/fenics/test_base.py b/tests/fenics/test_base.py index b63b1082..a3e91a30 100644 --- a/tests/fenics/test_base.py +++ b/tests/fenics/test_base.py @@ -5,6 +5,7 @@ backend_Constant, backend_Function, complex_mode) from tlm_adjoint.fenics.interpolation import interpolate_expression from tlm_adjoint.alias import gc_disabled +from tlm_adjoint.markers import AdjointActionMarker from tlm_adjoint.patch import patch_method from ..test_base import chdir_tmp_path, jax_tlm_config, seed_test, tmp_path @@ -148,6 +149,10 @@ def test_leaks(): for tlm_map in manager._tlm_map.values(): del tlm_map._M, tlm_map._dM manager._adj_cache.clear() + for block in list(manager._blocks) + [manager._block]: + for eq in block: + if isinstance(eq, AdjointActionMarker): + del eq._adj_X gc.collect() garbage_cleanup(DEFAULT_COMM) diff --git a/tests/firedrake/test_base.py b/tests/firedrake/test_base.py index 2601b080..6a6bd258 100644 --- a/tests/firedrake/test_base.py +++ b/tests/firedrake/test_base.py @@ -5,6 +5,7 @@ backend_Cofunction, backend_Constant, backend_Function, complex_mode) from tlm_adjoint.firedrake.interpolation import interpolate_expression from tlm_adjoint.alias import gc_disabled +from tlm_adjoint.markers import AdjointActionMarker from tlm_adjoint.patch import patch_method from ..test_base import chdir_tmp_path, jax_tlm_config, seed_test, tmp_path @@ -155,6 +156,8 @@ def test_leaks(): for eq in block: if isinstance(eq, PointInterpolation): del eq._interp + elif isinstance(eq, AdjointActionMarker): + del eq._adj_X gc.collect() garbage_cleanup(DEFAULT_COMM) diff --git a/tests/firedrake/test_torch.py b/tests/firedrake/test_torch.py new file mode 100644 index 00000000..5b4bd4b3 --- /dev/null +++ b/tests/firedrake/test_torch.py @@ -0,0 +1,92 @@ +from firedrake import * +from tlm_adjoint.firedrake import * +from tlm_adjoint.torch import ( + from_torch_tensors, to_torch_tensors, torch_wrapped) + +from .test_base import * + +import pytest +try: + import torch +except ImportError: + torch = None + +pytestmark = pytest.mark.skipif( + DEFAULT_COMM.size not in {1, 4}, + reason="tests must be run in serial, or with 4 processes") +pytestmark = pytest.mark.skipif( + torch is None, + reason="PyTorch not available") + + +@pytest.mark.firedrake +@seed_test +def test_torch_tensor_roundtrip(setup_test, test_leaks): + mesh = UnitSquareMesh(10, 10) + X = SpatialCoordinate(mesh) + space1 = FunctionSpace(mesh, "Lagrange", 1) + space2 = FunctionSpace(mesh, "Lagrange", 2) + + u = Function(space1).interpolate(exp(X[0])) + v = Function(space2).interpolate(sin(pi * X[1])) + c = Constant(sqrt(2.0)) + + for x in u, v, c: + y = var_new(x) + from_torch_tensors(y, 
to_torch_tensors(x)) + + err = var_copy(x) + var_axpy(err, -1.0, y) + assert var_linf_norm(err) == 0.0 + + +@pytest.mark.firedrake +@seed_test +def test_torch_wrapped(setup_test, test_leaks): + mesh = UnitSquareMesh(10, 10) + X = SpatialCoordinate(mesh) + space = FunctionSpace(mesh, "Lagrange", 1) + + m = Function(space).interpolate(X[0]) + x = Function(space) + + def forward(m): + return m.copy(deepcopy=True) + + _, _, x_t = torch_wrapped(forward, m) + from_torch_tensors(x, x_t) + + err = var_copy(x) + var_axpy(err, -1.0, m) + assert var_linf_norm(err) == 0.0 + + +@pytest.mark.firedrake +@pytest.mark.skipif(DEFAULT_COMM.size > 1, reason="serial only") +@seed_test +def test_torch_vjp(setup_test, test_leaks): + mesh = UnitSquareMesh(10, 10) + X = SpatialCoordinate(mesh) + space = FunctionSpace(mesh, "Lagrange", 1) + + if complex_mode: + m = Function(space).interpolate(X[0] + 1.0j * X[1]) + else: + m = Function(space).interpolate(X[0]) + J = Float(name="J") + + def forward(m): + J = Functional(name="J") + J.assign((m ** 4) * dx) + return J + + J_ref = assemble((m ** 4) * dx) + _, forward_t, J_t = torch_wrapped(forward, m) + from_torch_tensors(J, J_t) + assert abs(complex(J) - complex(J_ref)) == 0.0 + + dm = Function(space, name="dm").interpolate(Constant(1.0)) + dm_t = to_torch_tensors(dm, requires_grad=True) + + assert torch.autograd.gradcheck(forward_t, dm_t, eps=1.0e-8, + atol=1.0e-8, rtol=1.0e-7) diff --git a/tlm_adjoint/hessian.py b/tlm_adjoint/hessian.py index 9bb53925..7621d5a5 100644 --- a/tlm_adjoint/hessian.py +++ b/tlm_adjoint/hessian.py @@ -1,15 +1,15 @@ from .interface import ( Packed, check_space_types_conjugate_dual, packed, var_axpy, var_copy, var_copy_conjugate, var_is_cached, var_is_static, var_locked, var_name, - var_new, var_scalar_value) + var_scalar_value) from .caches import local_caches -from .equations import InnerProduct from .functional import Functional +from .markers import AdjointActionMarker from .manager import manager as _manager from .manager import ( - compute_gradient, configure_tlm, var_tlm, paused_manager, - reset_manager, restore_manager, set_manager, start_manager, stop_manager) + compute_gradient, configure_tlm, var_tlm, reset_manager, restore_manager, + set_manager, start_manager, stop_manager) from abc import ABC, abstractmethod import warnings @@ -252,12 +252,7 @@ def action(self, M, dM, M0=None): # J^T action start_manager() J = Functional() - assert len(X) == len(R_inv_tau_X) - for x, R_inv_tau_x in zip(X, R_inv_tau_X): - J_term = var_new(J) - with paused_manager(annotate=False, tlm=True): - InnerProduct(J_term, x, var_copy(R_inv_tau_x)).solve() - J.addto(J_term) + AdjointActionMarker(J, X, tuple(map(var_copy, R_inv_tau_X))).solve() stop_manager() # Likelihood term: conj[ J^T R^{-1} J dM ] diff --git a/tlm_adjoint/markers.py b/tlm_adjoint/markers.py index ca230df5..10cce824 100644 --- a/tlm_adjoint/markers.py +++ b/tlm_adjoint/markers.py @@ -1,11 +1,14 @@ -from .interface import Packed, var_new +from .interface import ( + Packed, check_space_types_conjugate_dual, is_var, var_assign, var_inner, + var_is_scalar, var_new, var_scalar_value) -from .equation import Equation +from .equation import Equation, ZeroAssignment __all__ = \ [ "ControlsMarker", - "FunctionalMarker" + "FunctionalMarker", + "AdjointActionMarker" ] @@ -62,6 +65,9 @@ class FunctionalMarker(Equation): """ def __init__(self, J): + if not var_is_scalar(J): + raise ValueError("Functional must be a scalar variable") + # Extra variable allocation could be avoided J_ = var_new(J) 
super().__init__(J_, [J_, J], nl_deps=[], ic=False, adj_ic=False) @@ -74,3 +80,93 @@ def adjoint_derivative_action(self, nl_deps, dep_index, adj_x): def adjoint_jacobian_solve(self, adj_x, nl_deps, b): return b + + +class AdjointActionMarker(Equation): + r"""Represents + + .. math:: + + J_\text{output} = \lambda_x^* x, + + with forward residual + + .. math:: + + \mathcal{F} \left( J_\text{output}, x \right) + = J_\text{output} - \lambda_x^* x. + + Note that :math:`\lambda_x` is *not* treated as a dependency. + + Can be used to initialize an adjoint calculation, and compute adjoint + Jacobian actions, via the construction + + .. code-block:: python + + start_manager() + X = forward(M) + adj_X = ... + J = Float(name="J") + AdjointRHSMarker(J, X, adj_X).solve() + stop_manager() + + # Compute the action of the adjoint of the Jacobian on the direction + # defined by adj_X + dJ = compute_gradient(J, M) + + :arg J: A variable defining the functional :math:`J`. + :arg X: A variable or :class:`Sequence` of variables defining :math:`x`. + :arg adj_X: A variable or :class:`Sequence` of variables defining + :math:`\lambda_x`. + """ + + def __init__(self, J, X, adj_X): + if not var_is_scalar(J): + raise ValueError("Functional must be a scalar variable") + if is_var(X): + X = (X,) + if is_var(adj_X): + adj_X = (adj_X,) + if len(X) != len(adj_X): + raise ValueError("Invalid length") + for x, adj_x in zip(X, adj_X): + check_space_types_conjugate_dual(x, adj_x) + + super().__init__(J, [J] + list(X), nl_deps=X, ic=False, adj_ic=False) + self._adj_X = tuple(adj_X) + + def forward_solve(self, x, deps=None): + J = x + X = (self.dependencies() if deps is None else deps)[1:] + + v = 0.0 + assert len(X) == len(self._adj_X) + for x, adj_x in zip(X, self._adj_X): + v += var_inner(x, adj_x) + var_assign(J, v) + + def adjoint_derivative_action(self, nl_deps, dep_index, adj_x): + if dep_index == 0: + raise ValueError("Unexpected dep_index") + return (-var_scalar_value(adj_x), self._adj_X[dep_index - 1]) + + def adjoint_jacobian_solve(self, adj_x, nl_deps, b): + return b + + def tangent_linear(self, tlm_map): + J = self.x() + X = self.dependencies()[1:] + + tau_X = [] + adj_X = [] + assert len(X) == len(self._adj_X) + for x, adj_x in zip(X, self._adj_X): + tau_x = tlm_map[x] + if tau_x is not None: + tau_X.append(tau_x) + adj_X.append(adj_x) + + if len(tau_X) == 0: + return ZeroAssignment(tlm_map[J]) + else: + return AdjointActionMarker(tlm_map[J], tau_X, adj_X) diff --git a/tlm_adjoint/torch.py b/tlm_adjoint/torch.py new file mode 100644 index 00000000..1c672d5a --- /dev/null +++ b/tlm_adjoint/torch.py @@ -0,0 +1,163 @@ +"""Interface with PyTorch. + +Can be used to embed models, differentiated by tlm_adjoint, within a PyTorch +calculation. Follows the same principles as described in + + - Nacime Bouziani and David A. 
Ham, 'Physics-driven machine learning models + coupling PyTorch and Firedrake', 2023, arXiv:2303.06871v3 +""" + +from .caches import clear_caches +from .interface import ( + is_var, var_comm, var_dtype, var_get_values, var_id, var_new, + var_new_conjugate_dual, var_set_values) +from .manager import ( + compute_gradient, manager as _manager, reset_manager, restore_manager, + set_manager, start_manager, stop_manager) +from .markers import AdjointActionMarker +from .overloaded_float import Float + +try: + import torch +except ImportError: + torch = None + +__all__ = \ + [ + "to_torch_tensors", + "from_torch_tensors", + "torch_wrapped" + ] + + +def to_torch_tensor(x, *args, **kwargs): + return torch.tensor(var_get_values(x), *args, **kwargs) + + +def to_torch_tensors(X, *args, **kwargs): + """Convert one or more variables to :class:`torch.Tensor` objects. + + :arg X: A variable or :class:`Sequence` or variables. + :returns: A :class:`torch.Tensor` or :class:`tuple` of + :class:`torch.Tensor` objects. + + Remaining arguments are passed to :func:`torch.tensor`. + """ + + if is_var(X): + return (to_torch_tensor(X, *args, **kwargs),) + else: + return tuple(to_torch_tensor(x, *args, **kwargs) for x in X) + + +def from_torch_tensor(x, x_t): + var_set_values(x, x_t.detach().numpy()) + return x + + +def from_torch_tensors(X, X_t): + """Copy data from PyTorch tensors into variables. + + :arg X: A variable or :class:`Sequence` or variables. + :arg X_t: A :class:`torch.Tensor` or :class:`Sequence` of + :class:`torch.Tensor` objects. + """ + + if is_var(X): + X_t, = X_t + from_torch_tensor(X, X_t) + else: + if len(X) != len(X_t): + raise ValueError("Invalid length") + for x, x_t in zip(X, X_t): + from_torch_tensor(x, x_t) + + +@restore_manager +def _forward(forward, M, manager): + set_manager(manager) + reset_manager() + clear_caches() + + start_manager() + X = forward(*M) + if is_var(X): + n = None + X = (X,) + else: + n = len(X) + if n == 0: + raise RuntimeError("forward must return at least one variable") + J = Float(dtype=var_dtype(X[0]), comm=var_comm(X[0])) + adj_X = tuple(map(var_new_conjugate_dual, X)) + AdjointActionMarker(J, X, adj_X).solve() + stop_manager() + + if n is None: + x, = X + adj_x, = adj_X + return x, J, adj_x + else: + return X, J, adj_X + + +class TorchInterface(object if torch is None else torch.autograd.Function): + @staticmethod + def forward(ctx, forward, manager, J_id, M, *M_t): + M = tuple(map(var_new, M)) + from_torch_tensors(M, M_t) + + X, J, adj_X = _forward(forward, M, manager) + + J_id[0] = var_id(J) + ctx._tlm_adjoint__output_ctx = (forward, manager, J_id, M, J, adj_X) + return to_torch_tensors(X) + + @staticmethod + @restore_manager + def backward(ctx, *adj_X_t): + forward, manager, J_id, M, J, adj_X = ctx._tlm_adjoint__output_ctx + if var_id(J) != J_id[0] or manager._cp_schedule.is_exhausted: + _, J, adj_X = _forward(forward, M, manager) + J_id[0] = var_id(J) + + from_torch_tensors(adj_X, adj_X_t) + set_manager(manager) + dJ = compute_gradient(J, M) + + return (None, None, None, None) + to_torch_tensors(dJ) + + +def torch_wrapped(forward, M, *, manager=None): + """Wrap a model, differentiated using tlm_adjoint, so that it can be used + with PyTorch. + + :arg forward: A callable which accepts one or more variable arguments, and + returns a variable or :class:`Sequence` of variables. + :arg M: A variable or :class:`Sequence` of variables defining the input to + `forward`. 
+    :arg manager: An :class:`.EquationManager` used to create an internal
+        manager via :meth:`.EquationManager.new`. `manager()` is used if not
+        supplied.
+    :returns: A :class:`tuple` `(M_t, forward_t, X_t)`, where
+
+        - `M_t` is a tuple of :class:`torch.Tensor` objects storing `M`.
+        - `forward_t` is a version of `forward` with :class:`torch.Tensor`
+          inputs and outputs.
+        - `X_t` is a tuple of :class:`torch.Tensor` objects storing the
+          value of `forward` evaluated with `M` as input.
+    """
+
+    if is_var(M):
+        M = (M,)
+    if manager is None:
+        manager = _manager()
+    manager = manager.new()
+    J_id = [None]
+
+    M_t = to_torch_tensors(M, requires_grad=True)
+
+    def forward_t(*M_t):
+        return TorchInterface.apply(forward, manager, J_id, M, *M_t)
+
+    return M_t, forward_t, forward_t(*M_t)
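
A minimal usage sketch of the interface added in tlm_adjoint/torch.py, assuming the base Float backend, a serial run, and a real (np.double) default dtype. The forward model, the values, and the use of torch.autograd.grad here are illustrative; the patch itself exercises the interface through the tests above.

# Wrap a simple tlm_adjoint forward model, built on the overloaded Float
# scalar, so that torch.autograd can differentiate through it.
import numpy as np
import torch

from tlm_adjoint import Float, set_default_float_dtype
from tlm_adjoint.torch import from_torch_tensors, torch_wrapped

set_default_float_dtype(np.double)


def forward(m):
    # Functional J(m) = m ** 4; annotation is handled inside torch_wrapped
    return m ** 4


m = Float(2.0)

# m_t: tuple of torch.Tensor objects storing m (with requires_grad=True)
# forward_t: version of forward with torch.Tensor inputs and outputs
# J_t: tuple of torch.Tensor objects storing the value of forward(m)
m_t, forward_t, J_t = torch_wrapped(forward, m)

# Copy the PyTorch result back into a tlm_adjoint variable
J = Float(name="J")
from_torch_tensors(J, J_t)
assert abs(complex(J) - 16.0) < 1.0e-12

# Reverse-mode differentiation through the wrapped model; the gradient is
# computed by tlm_adjoint's adjoint machinery via TorchInterface.backward
dJ_dm_t, = torch.autograd.grad(J_t[0].sum(), m_t)
print(dJ_dm_t)  # expect dJ/dm = 4 * m ** 3 = 32 at m = 2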