Commit

PyTorch interface

jrmaddison committed May 29, 2024
1 parent cc2bc1e commit 4d22b2a
Showing 9 changed files with 458 additions and 13 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/test-firedrake.yml
@@ -31,6 +31,7 @@ jobs:
run: |
. /home/firedrake/firedrake/bin/activate
python3 -m pip install jax[cpu] ruff pytest-timeout pytest-xdist
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
- name: Lint
run: |
. /home/firedrake/firedrake/bin/activate
@@ -60,6 +61,7 @@ jobs:
run: |
. /home/firedrake/firedrake/bin/activate
python3 -m pip install jax[cpu] ruff pytest-timeout pytest-xdist
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
- name: Lint
run: |
. /home/firedrake/firedrake/bin/activate
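The workflow now installs a CPU-only PyTorch wheel alongside JAX in both Firedrake jobs. A minimal sanity check one could run in the activated environment (hypothetical, not part of the workflow) to confirm the CPU wheel was picked up:

```python
# Hypothetical post-install check, not part of the workflow: confirm that
# torch imports and that the CPU-only wheel (no CUDA support) was installed.
import torch

print(torch.__version__)              # CPU wheels report a "+cpu" local version
assert not torch.cuda.is_available()  # no CUDA devices exposed by the CPU build
```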
1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -40,4 +40,5 @@
"python": ("https://docs.python.org/3", None),
"scipy": ("https://docs.scipy.org/doc/scipy", None),
"sympy": ("https://docs.sympy.org/latest", None),
"torch": ("https://pytorch.org/docs/stable", None),
"ufl": ("https://fenics.readthedocs.io/projects/ufl/en/latest", None)} # noqa: E501
88 changes: 88 additions & 0 deletions tests/base/test_torch.py
@@ -0,0 +1,88 @@
from tlm_adjoint import DEFAULT_COMM, Float, set_default_float_dtype, var_id
from tlm_adjoint.torch import (
from_torch_tensors, to_torch_tensors, torch_wrapped)

from .test_base import setup_test # noqa: F401

import numpy as np
import pytest
try:
import torch
except ImportError:
torch = None

pytestmark = [
    pytest.mark.skipif(
        DEFAULT_COMM.size not in {1, 4},
        reason="tests must be run in serial, or with 4 processes"),
    pytest.mark.skipif(
        torch is None,
        reason="PyTorch not available")]


@pytest.mark.base
@pytest.mark.parametrize("dtype", [np.double, np.cdouble])
def test_torch_tensor_roundtrip(setup_test, # noqa: F811
dtype):
set_default_float_dtype(dtype)

if issubclass(dtype, np.complexfloating):
x = Float(-np.sqrt(2.0) + 1.0j * np.sqrt(3.0))
else:
x = Float(-np.sqrt(2.0))
y = Float()
from_torch_tensors(y, to_torch_tensors(x))
assert abs(complex(x) - complex(y)) == 0.0


@pytest.mark.base
@pytest.mark.parametrize("dtype", [np.double, np.cdouble])
def test_torch_wrapped(setup_test, # noqa: F811
dtype):
set_default_float_dtype(dtype)

if issubclass(dtype, np.complexfloating):
m = Float(-np.sqrt(2.0) + 1.0j * np.sqrt(3.0))
else:
m = Float(-np.sqrt(2.0))
x = Float()

def forward(m):
return Float(m)

_, _, x_t = torch_wrapped(forward, m)
from_torch_tensors(x, x_t)

assert x is not m
assert var_id(x) != var_id(m)
assert abs(complex(x) - complex(m)) == 0.0


@pytest.mark.base
@pytest.mark.parametrize("dtype", [np.double, np.cdouble])
@pytest.mark.skipif(DEFAULT_COMM.size > 1, reason="serial only")
def test_torch_vjp(setup_test, # noqa: F811
dtype):
set_default_float_dtype(dtype)

if issubclass(dtype, np.complexfloating):
m = Float(-np.sqrt(2.0) + 1.0j * np.sqrt(3.0))
else:
m = Float(-np.sqrt(2.0))
J = Float(name="J")

def forward(m):
return m ** 4

J_ref = complex(m) ** 4
_, forward_t, J_t = torch_wrapped(forward, m)
from_torch_tensors(J, J_t)
assert abs(complex(J) - complex(J_ref)) < 1.0e-15

if issubclass(dtype, np.complexfloating):
dm = Float(1.0 - 1.0j)
else:
dm = Float(1.0)
dm_t = to_torch_tensors(dm, requires_grad=True)

assert torch.autograd.gradcheck(forward_t, dm_t, eps=1.0e-8,
atol=1.0e-8, rtol=1.0e-7)
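Outside pytest, the same interface can be exercised directly. A minimal sketch mirroring the tests above; the return convention of `torch_wrapped` (its first element is discarded here, as in the tests) is inferred from their usage rather than from documentation:

```python
# Minimal usage sketch based on the tests above (assumes tlm_adjoint and a
# CPU build of PyTorch are importable; the torch_wrapped return convention is
# inferred from the tests).
from tlm_adjoint import Float
from tlm_adjoint.torch import (
    from_torch_tensors, to_torch_tensors, torch_wrapped)

import torch

m = Float(2.0)

def forward(m):
    return m ** 4

# As in the tests: the second and third return values are a torch-callable
# forward and the torch tensors holding the forward output.
_, forward_t, J_t = torch_wrapped(forward, m)

J = Float()
from_torch_tensors(J, J_t)
assert abs(complex(J) - complex(m) ** 4) < 1.0e-12

# The wrapped forward participates in torch autograd, e.g. gradient checking
# at a chosen evaluation point.
dm_t = to_torch_tensors(Float(1.0), requires_grad=True)
assert torch.autograd.gradcheck(forward_t, dm_t, eps=1.0e-8,
                                atol=1.0e-8, rtol=1.0e-7)
```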
5 changes: 5 additions & 0 deletions tests/fenics/test_base.py
@@ -5,6 +5,7 @@
backend_Constant, backend_Function, complex_mode)
from tlm_adjoint.fenics.interpolation import interpolate_expression
from tlm_adjoint.alias import gc_disabled
from tlm_adjoint.markers import AdjointActionMarker
from tlm_adjoint.patch import patch_method

from ..test_base import chdir_tmp_path, jax_tlm_config, seed_test, tmp_path
@@ -148,6 +149,10 @@ def test_leaks():
for tlm_map in manager._tlm_map.values():
del tlm_map._M, tlm_map._dM
manager._adj_cache.clear()
for block in list(manager._blocks) + [manager._block]:
for eq in block:
if isinstance(eq, AdjointActionMarker):
del eq._adj_X

gc.collect()
garbage_cleanup(DEFAULT_COMM)
3 changes: 3 additions & 0 deletions tests/firedrake/test_base.py
@@ -5,6 +5,7 @@
backend_Cofunction, backend_Constant, backend_Function, complex_mode)
from tlm_adjoint.firedrake.interpolation import interpolate_expression
from tlm_adjoint.alias import gc_disabled
from tlm_adjoint.markers import AdjointActionMarker
from tlm_adjoint.patch import patch_method

from ..test_base import chdir_tmp_path, jax_tlm_config, seed_test, tmp_path
@@ -155,6 +156,8 @@ def test_leaks():
for eq in block:
if isinstance(eq, PointInterpolation):
del eq._interp
elif isinstance(eq, AdjointActionMarker):
del eq._adj_X

gc.collect()
garbage_cleanup(DEFAULT_COMM)
92 changes: 92 additions & 0 deletions tests/firedrake/test_torch.py
@@ -0,0 +1,92 @@
from firedrake import *
from tlm_adjoint.firedrake import *
from tlm_adjoint.torch import (
from_torch_tensors, to_torch_tensors, torch_wrapped)

from .test_base import *

import pytest
try:
import torch
except ImportError:
torch = None

pytestmark = [
    pytest.mark.skipif(
        DEFAULT_COMM.size not in {1, 4},
        reason="tests must be run in serial, or with 4 processes"),
    pytest.mark.skipif(
        torch is None,
        reason="PyTorch not available")]


@pytest.mark.firedrake
@seed_test
def test_torch_tensor_roundtrip(setup_test, test_leaks):
mesh = UnitSquareMesh(10, 10)
X = SpatialCoordinate(mesh)
space1 = FunctionSpace(mesh, "Lagrange", 1)
space2 = FunctionSpace(mesh, "Lagrange", 2)

u = Function(space1).interpolate(exp(X[0]))
v = Function(space2).interpolate(sin(pi * X[1]))
c = Constant(sqrt(2.0))

for x in u, v, c:
y = var_new(x)
from_torch_tensors(y, to_torch_tensors(x))

err = var_copy(x)
var_axpy(err, -1.0, y)
assert var_linf_norm(err) == 0.0


@pytest.mark.firedrake
@seed_test
def test_torch_wrapped(setup_test, test_leaks):
mesh = UnitSquareMesh(10, 10)
X = SpatialCoordinate(mesh)
space = FunctionSpace(mesh, "Lagrange", 1)

m = Function(space).interpolate(X[0])
x = Function(space)

def forward(m):
return m.copy(deepcopy=True)

_, _, x_t = torch_wrapped(forward, m)
from_torch_tensors(x, x_t)

err = var_copy(x)
var_axpy(err, -1.0, m)
assert var_linf_norm(err) == 0.0


@pytest.mark.firedrake
@pytest.mark.skipif(DEFAULT_COMM.size > 1, reason="serial only")
@seed_test
def test_torch_vjp(setup_test, test_leaks):
mesh = UnitSquareMesh(10, 10)
X = SpatialCoordinate(mesh)
space = FunctionSpace(mesh, "Lagrange", 1)

if complex_mode:
m = Function(space).interpolate(X[0] + 1.0j * X[1])
else:
m = Function(space).interpolate(X[0])
J = Float(name="J")

def forward(m):
J = Functional(name="J")
J.assign((m ** 4) * dx)
return J

J_ref = assemble((m ** 4) * dx)
_, forward_t, J_t = torch_wrapped(forward, m)
from_torch_tensors(J, J_t)
assert abs(complex(J) - complex(J_ref)) == 0.0

dm = Function(space, name="dm").interpolate(Constant(1.0))
dm_t = to_torch_tensors(dm, requires_grad=True)

assert torch.autograd.gradcheck(forward_t, dm_t, eps=1.0e-8,
atol=1.0e-8, rtol=1.0e-7)
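A hypothetical sketch of how the wrapped forward could be composed with further PyTorch operations, so that gradients of a downstream torch scalar flow back through the Firedrake model; the calling convention for `forward_t` (unpacked input tensors in, a tuple of output tensors out) is inferred from the `gradcheck` usage in the tests above:

```python
# Hypothetical sketch (not part of the commit): compose the wrapped Firedrake
# forward with further torch operations and backpropagate through it.
# Assumes a real (non-complex) Firedrake build; the forward_t calling
# convention is inferred from the gradcheck usage above.
from firedrake import (
    Function, FunctionSpace, SpatialCoordinate, UnitSquareMesh, dx)
from tlm_adjoint.firedrake import Functional
from tlm_adjoint.torch import to_torch_tensors, torch_wrapped

mesh = UnitSquareMesh(10, 10)
X = SpatialCoordinate(mesh)
space = FunctionSpace(mesh, "Lagrange", 1)
m = Function(space, name="m").interpolate(X[0])

def forward(m):
    J = Functional(name="J")
    J.assign((m ** 3) * dx)
    return J

_, forward_t, _ = torch_wrapped(forward, m)

m_t = to_torch_tensors(m, requires_grad=True)
J_t, = forward_t(*m_t)     # assumed: a tuple holding a single scalar tensor
loss = 2.0 * J_t           # further torch operations compose as usual
loss.backward()            # populates m_t[0].grad with d(loss)/dm
```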
15 changes: 5 additions & 10 deletions tlm_adjoint/hessian.py
@@ -1,15 +1,15 @@
from .interface import (
Packed, check_space_types_conjugate_dual, packed, var_axpy, var_copy,
var_copy_conjugate, var_is_cached, var_is_static, var_locked, var_name,
var_new, var_scalar_value)
var_scalar_value)

from .caches import local_caches
from .equations import InnerProduct
from .functional import Functional
from .markers import AdjointActionMarker
from .manager import manager as _manager
from .manager import (
compute_gradient, configure_tlm, var_tlm, paused_manager,
reset_manager, restore_manager, set_manager, start_manager, stop_manager)
compute_gradient, configure_tlm, var_tlm, reset_manager, restore_manager,
set_manager, start_manager, stop_manager)

from abc import ABC, abstractmethod
import warnings
@@ -252,12 +252,7 @@ def action(self, M, dM, M0=None):
# J^T action
start_manager()
J = Functional()
assert len(X) == len(R_inv_tau_X)
for x, R_inv_tau_x in zip(X, R_inv_tau_X):
J_term = var_new(J)
with paused_manager(annotate=False, tlm=True):
InnerProduct(J_term, x, var_copy(R_inv_tau_x)).solve()
J.addto(J_term)
AdjointActionMarker(J, X, tuple(map(var_copy, R_inv_tau_X))).solve()
stop_manager()

# Likelihood term: conj[ J^T R^{-1} J dM ]
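In the J^T action above, the removed loop and the new `AdjointActionMarker` record the same scalar. In the notation of the marker's docstring (markers.py below), with \lambda_{x,i} identified with the copies of R^{-1} \tau_{x,i}, the recorded functional is

    J = \sum_i \lambda_{x,i}^* x_i

but the marker records it as a single equation whose adjoint directions are held outside the dependency graph, rather than accumulating one `InnerProduct` term per output.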
102 changes: 99 additions & 3 deletions tlm_adjoint/markers.py
@@ -1,11 +1,14 @@
from .interface import Packed, var_new
from .interface import (
Packed, check_space_types_conjugate_dual, is_var, var_assign, var_inner,
var_is_scalar, var_new, var_scalar_value)

from .equation import Equation
from .equation import Equation, ZeroAssignment

__all__ = \
[
"ControlsMarker",
"FunctionalMarker"
"FunctionalMarker",
"AdjointActionMarker"
]


@@ -62,6 +65,9 @@ class FunctionalMarker(Equation):
"""

def __init__(self, J):
if not var_is_scalar(J):
raise ValueError("Functional must be a scalar variable")

# Extra variable allocation could be avoided
J_ = var_new(J)
super().__init__(J_, [J_, J], nl_deps=[], ic=False, adj_ic=False)
@@ -74,3 +80,93 @@ def adjoint_derivative_action(self, nl_deps, dep_index, adj_x):

def adjoint_jacobian_solve(self, adj_x, nl_deps, b):
return b


class AdjointActionMarker(Equation):
r"""Represents
.. math::
J_\text{output} = \lambda_x^* x,
with forward residual
.. math::
\mathcal{F} \left( J_\text{output}, x \right)
= J_\text{output} - \lambda_x^* x.
Note that :math:`\lambda_x` is *not* treated as a dependency.
Can be used to initialize an adjoint calculation, and compute adjoint
Jacobian actions, via the construction
.. code-block:: python
start_manager()
X = forward(M)
adj_X = ...
J = Float(name="J")
AdjointRHSMarker(J, X, adj_X).solve()
stop_manager()
# Compute the action of the adjoint of the Jacobian on the direction
# defined by adj_X
dJ = compute_gradient(J, M)
:arg J: A variable defining the functional :math:`J`.
:arg X: A variable or :class:`Sequence` of variables defining :math:`x`.
:arg adj_X: A variable or :class:`Sequence` of variables defining
:math:`\lambda_x`.
"""

def __init__(self, J, X, adj_X):
if not var_is_scalar(J):
raise ValueError("Functional must be a scalar variable")
if is_var(X):
X = (X,)
if is_var(adj_X):
adj_X = (adj_X,)
if len(X) != len(adj_X):
raise ValueError("Invalid length")
for x, adj_x in zip(X, adj_X):
check_space_types_conjugate_dual(x, adj_x)

super().__init__(J, [J] + list(X), nl_deps=X, ic=False, adj_ic=False)
self._adj_X = tuple(adj_X)

def forward_solve(self, x, deps=None):
J = x
X = (self.dependencies() if deps is None else deps)[1:]

v = 0.0
assert len(X) == len(self._adj_X)
for x, adj_x in zip(X, self._adj_X):
v += var_inner(x, adj_x)
var_assign(J, v)

def adjoint_derivative_action(self, nl_deps, dep_index, adj_x):
if dep_index == 0:
raise ValueError("Unexpected dep_index")
return (-var_scalar_value(adj_x), self._adj_X[dep_index - 1])

def adjoint_jacobian_solve(self, adj_x, nl_deps, b):
return b

def tangent_linear(self, tlm_map):
J = self.x()
X = self.dependencies()[1:]

tau_X = []
adj_X = []
assert len(X) == len(self._adj_X)
for x, adj_x in zip(X, self._adj_X):
tau_x = tlm_map[x]
if tau_x is not None:
tau_X.append(tau_x)
adj_X.append(adj_x)

if len(tau_X) == 0:
return ZeroAssignment(tlm_map[J])
else:
return AdjointActionMarker(tlm_map[J], tau_X, adj_X)
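A minimal concrete sketch of the docstring's usage pattern, using scalar `Float` variables. The values are illustrative, and the conjugate-dual construction of the adjoint direction is an assumption made to satisfy the space-type check in `__init__`:

```python
# Minimal sketch of the docstring's usage pattern with scalar variables
# (illustrative values; assumes the listed tlm_adjoint names are importable
# from the package root).
from tlm_adjoint import (
    Float, compute_gradient, start_manager, stop_manager, var_assign,
    var_new_conjugate_dual)
from tlm_adjoint.markers import AdjointActionMarker

m = Float(3.0, name="m")

start_manager()
x = m ** 2                            # forward output x(m) = m^2
adj_x = var_new_conjugate_dual(x)     # adjoint direction lambda_x
var_assign(adj_x, 5.0)                # lambda_x = 5, held outside the tape
J = Float(name="J")
AdjointActionMarker(J, x, adj_x).solve()   # records J = lambda_x^* x
stop_manager()

# Action of the adjoint of dx/dm on lambda_x: here 5 * 2 * m = 30
dJ = compute_gradient(J, m)
print(complex(dJ))
```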
