Commit 4d22b2a (1 parent: cc2bc1e)
Showing 9 changed files with 458 additions and 13 deletions.
@@ -0,0 +1,88 @@
from tlm_adjoint import DEFAULT_COMM, Float, set_default_float_dtype, var_id
from tlm_adjoint.torch import (
    from_torch_tensors, to_torch_tensors, torch_wrapped)

from .test_base import setup_test  # noqa: F401

import numpy as np
import pytest
try:
    import torch
except ImportError:
    torch = None

# Both skip conditions must apply; a repeated bare assignment to pytestmark
# would keep only the last mark, so the marks are collected in a list.
pytestmark = [
    pytest.mark.skipif(
        DEFAULT_COMM.size not in {1, 4},
        reason="tests must be run in serial, or with 4 processes"),
    pytest.mark.skipif(
        torch is None,
        reason="PyTorch not available")]


@pytest.mark.base
@pytest.mark.parametrize("dtype", [np.double, np.cdouble])
def test_torch_tensor_roundtrip(setup_test,  # noqa: F811
                                dtype):
    set_default_float_dtype(dtype)

    if issubclass(dtype, np.complexfloating):
        x = Float(-np.sqrt(2.0) + 1.0j * np.sqrt(3.0))
    else:
        x = Float(-np.sqrt(2.0))
    y = Float()
    from_torch_tensors(y, to_torch_tensors(x))
    assert abs(complex(x) - complex(y)) == 0.0


@pytest.mark.base
@pytest.mark.parametrize("dtype", [np.double, np.cdouble])
def test_torch_wrapped(setup_test,  # noqa: F811
                       dtype):
    set_default_float_dtype(dtype)

    if issubclass(dtype, np.complexfloating):
        m = Float(-np.sqrt(2.0) + 1.0j * np.sqrt(3.0))
    else:
        m = Float(-np.sqrt(2.0))
    x = Float()

    def forward(m):
        return Float(m)

    _, _, x_t = torch_wrapped(forward, m)
    from_torch_tensors(x, x_t)

    assert x is not m
    assert var_id(x) != var_id(m)
    assert abs(complex(x) - complex(m)) == 0.0


@pytest.mark.base
@pytest.mark.parametrize("dtype", [np.double, np.cdouble])
@pytest.mark.skipif(DEFAULT_COMM.size > 1, reason="serial only")
def test_torch_vjp(setup_test,  # noqa: F811
                   dtype):
    set_default_float_dtype(dtype)

    if issubclass(dtype, np.complexfloating):
        m = Float(-np.sqrt(2.0) + 1.0j * np.sqrt(3.0))
    else:
        m = Float(-np.sqrt(2.0))
    J = Float(name="J")

    def forward(m):
        return m ** 4

    J_ref = complex(m) ** 4
    _, forward_t, J_t = torch_wrapped(forward, m)
    from_torch_tensors(J, J_t)
    assert abs(complex(J) - complex(J_ref)) < 1.0e-15

    if issubclass(dtype, np.complexfloating):
        dm = Float(1.0 - 1.0j)
    else:
        dm = Float(1.0)
    dm_t = to_torch_tensors(dm, requires_grad=True)

    assert torch.autograd.gradcheck(forward_t, dm_t, eps=1.0e-8,
                                    atol=1.0e-8, rtol=1.0e-7)
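
The wrapped forward returned above is an ordinary torch-autograd callable, so it can also be driven directly rather than only through gradcheck. The following is a minimal sketch, not part of this commit: it reuses only the tlm_adjoint calls exercised in the tests, and assumes that to_torch_tensors returns a tuple of tensors, as its use as the gradcheck inputs argument suggests.

import numpy as np
import torch

from tlm_adjoint import Float, set_default_float_dtype
from tlm_adjoint.torch import to_torch_tensors, torch_wrapped

set_default_float_dtype(np.double)

m = Float(2.0)


def forward(m):
    return m ** 4


# As in test_torch_vjp above, the second return value of torch_wrapped is
# the torch-callable version of the forward.
_, forward_t, _ = torch_wrapped(forward, m)

m_t = to_torch_tensors(m, requires_grad=True)  # assumed: a tuple of tensors
J_t = forward_t(*m_t)

# For this scalar forward, dJ/dm = 4 m^3, i.e. 32 at m = 2.
dJ_dm, = torch.autograd.grad(J_t, m_t)
print(float(dJ_dm))  # expected: 32.0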
@@ -0,0 +1,92 @@
from firedrake import *
from tlm_adjoint.firedrake import *
from tlm_adjoint.torch import (
    from_torch_tensors, to_torch_tensors, torch_wrapped)

from .test_base import *

import pytest
try:
    import torch
except ImportError:
    torch = None

# As in the base tests, both skip conditions must apply, so the marks are
# collected in a list rather than assigned twice.
pytestmark = [
    pytest.mark.skipif(
        DEFAULT_COMM.size not in {1, 4},
        reason="tests must be run in serial, or with 4 processes"),
    pytest.mark.skipif(
        torch is None,
        reason="PyTorch not available")]


@pytest.mark.firedrake
@seed_test
def test_torch_tensor_roundtrip(setup_test, test_leaks):
    mesh = UnitSquareMesh(10, 10)
    X = SpatialCoordinate(mesh)
    space1 = FunctionSpace(mesh, "Lagrange", 1)
    space2 = FunctionSpace(mesh, "Lagrange", 2)

    u = Function(space1).interpolate(exp(X[0]))
    v = Function(space2).interpolate(sin(pi * X[1]))
    c = Constant(sqrt(2.0))

    for x in u, v, c:
        y = var_new(x)
        from_torch_tensors(y, to_torch_tensors(x))

        err = var_copy(x)
        var_axpy(err, -1.0, y)
        assert var_linf_norm(err) == 0.0


@pytest.mark.firedrake
@seed_test
def test_torch_wrapped(setup_test, test_leaks):
    mesh = UnitSquareMesh(10, 10)
    X = SpatialCoordinate(mesh)
    space = FunctionSpace(mesh, "Lagrange", 1)

    m = Function(space).interpolate(X[0])
    x = Function(space)

    def forward(m):
        return m.copy(deepcopy=True)

    _, _, x_t = torch_wrapped(forward, m)
    from_torch_tensors(x, x_t)

    err = var_copy(x)
    var_axpy(err, -1.0, m)
    assert var_linf_norm(err) == 0.0


@pytest.mark.firedrake
@pytest.mark.skipif(DEFAULT_COMM.size > 1, reason="serial only")
@seed_test
def test_torch_vjp(setup_test, test_leaks):
    mesh = UnitSquareMesh(10, 10)
    X = SpatialCoordinate(mesh)
    space = FunctionSpace(mesh, "Lagrange", 1)

    if complex_mode:
        m = Function(space).interpolate(X[0] + 1.0j * X[1])
    else:
        m = Function(space).interpolate(X[0])
    J = Float(name="J")

    def forward(m):
        J = Functional(name="J")
        J.assign((m ** 4) * dx)
        return J

    J_ref = assemble((m ** 4) * dx)
    _, forward_t, J_t = torch_wrapped(forward, m)
    from_torch_tensors(J, J_t)
    assert abs(complex(J) - complex(J_ref)) == 0.0

    dm = Function(space, name="dm").interpolate(Constant(1.0))
    dm_t = to_torch_tensors(dm, requires_grad=True)

    assert torch.autograd.gradcheck(forward_t, dm_t, eps=1.0e-8,
                                    atol=1.0e-8, rtol=1.0e-7)
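
A natural follow-on use of the same machinery, sketched below under a flagged assumption and not part of this commit: after evaluating the functional through the torch-callable forward, torch.autograd.grad can return the gradient to the Firedrake side. The assumption is that the gradient tensors share the layout produced by to_torch_tensors(m), so that from_torch_tensors can map them back onto a Function.

from firedrake import *
from tlm_adjoint.firedrake import *
from tlm_adjoint.torch import (
    from_torch_tensors, to_torch_tensors, torch_wrapped)

import torch

mesh = UnitSquareMesh(10, 10)
X = SpatialCoordinate(mesh)
space = FunctionSpace(mesh, "Lagrange", 1)
m = Function(space).interpolate(X[0])


def forward(m):
    J = Functional(name="J")
    J.assign((m ** 4) * dx)
    return J


_, forward_t, _ = torch_wrapped(forward, m)
m_t = to_torch_tensors(m, requires_grad=True)
J_t = forward_t(*m_t)

# Assumption: the gradient tensors match the layout of to_torch_tensors(m),
# so they can be mapped back onto a Function holding dJ/dm.
dJ_dm_t = torch.autograd.grad(J_t, m_t)
dJ_dm = Function(space, name="dJ_dm")
from_torch_tensors(dJ_dm, dJ_dm_t)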