Skip to content

Commit

Permalink
Integrate checkpoint_schedules library
Browse files Browse the repository at this point in the history
  • Loading branch information
jrmaddison committed Feb 15, 2024
1 parent 4518fed commit d374b53
Show file tree
Hide file tree
Showing 15 changed files with 210 additions and 967 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/test-base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@ jobs:
sudo apt-get update
sudo apt-get install flake8 python3-h5py python3-numpy python3-pytest python3-pytest-timeout python3-pytest-xdist python3-scipy python3-sympy
python3 -m pip install ruff
git clone --depth 1 https://github.com/firedrakeproject/checkpoint_schedules.git
- name: Lint
run: |
cd tlm_adjoint
flake8
ruff check
- name: Run tests
run: |
export PYTHONPATH=$PWD/checkpoint_schedules:$PYTHONPATH
cd tlm_adjoint
pytest-3 -v -n 2 --timeout=300 --timeout-method=thread
2 changes: 2 additions & 0 deletions .github/workflows/test-fenics.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@ jobs:
sudo apt-get update
sudo apt-get install flake8 python3-dolfin python3-h5py python3-numpy python3-pytest python3-pytest-timeout python3-pytest-xdist python3-scipy python3-sympy
python3 -m pip install ruff
git clone --depth 1 https://github.com/firedrakeproject/checkpoint_schedules.git
- name: Lint
run: |
cd tlm_adjoint
flake8
ruff check
- name: Run tests
run: |
export PYTHONPATH=$PWD/checkpoint_schedules:$PYTHONPATH
cd tlm_adjoint
pytest-3 -v tests/base tests/checkpoint_schedules tests/fenics -n 2 --timeout=300 --timeout-method=thread
4 changes: 4 additions & 0 deletions .github/workflows/test-firedrake.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ jobs:
run: |
. /home/firedrake/firedrake/bin/activate
python3 -m pip install jax jaxlib ruff pytest-timeout pytest-xdist
git clone --depth 1 https://github.com/firedrakeproject/checkpoint_schedules.git
- name: Lint
run: |
. /home/firedrake/firedrake/bin/activate
Expand All @@ -40,6 +41,7 @@ jobs:
- name: Run tests
run: |
. /home/firedrake/firedrake/bin/activate
export PYTHONPATH=$PWD/checkpoint_schedules:$PYTHONPATH
cd tlm_adjoint
pytest -v tests/base tests/checkpoint_schedules tests/firedrake -n 2 --timeout=300 --timeout-method=thread
test-complex:
Expand All @@ -60,6 +62,7 @@ jobs:
run: |
. /home/firedrake/firedrake/bin/activate
python3 -m pip install jax jaxlib ruff pytest-timeout pytest-xdist
git clone --depth 1 https://github.com/firedrakeproject/checkpoint_schedules.git
- name: Lint
run: |
. /home/firedrake/firedrake/bin/activate
Expand All @@ -69,5 +72,6 @@ jobs:
- name: Run tests
run: |
. /home/firedrake/firedrake/bin/activate
export PYTHONPATH=$PWD/checkpoint_schedules:$PYTHONPATH
cd tlm_adjoint
pytest -v tests/base tests/checkpoint_schedules tests/firedrake -n 2 --timeout=300 --timeout-method=thread
1 change: 1 addition & 0 deletions docs/source/dependencies.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ tlm_adjoint requires:

- `NumPy <https://numpy.org/>`_
- `SymPy <https://www.sympy.org>`_
- `checkpoint_schedules <https://www.firedrakeproject.org/checkpoint_schedules/>`_

Backend dependencies
--------------------
Expand Down
9 changes: 1 addition & 8 deletions tests/checkpoint_schedules/test_validity.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@
import functools
import pytest

try:
import hrevolve
except ImportError:
hrevolve = None
try:
import mpi4py.MPI as MPI
except ImportError:
Expand Down Expand Up @@ -72,10 +68,7 @@ def mixed(n, s):
(two_level, {"period": 2}),
(two_level, {"period": 7}),
(two_level, {"period": 10}),
pytest.param(
h_revolve, {},
marks=pytest.mark.skipif(hrevolve is None,
reason="H-Revolve not available")),
(h_revolve, {}),
(mixed, {})])
@pytest.mark.parametrize("n, S", [(1, (0,)),
(2, (1,)),
Expand Down
8 changes: 3 additions & 5 deletions tests/fenics/test_equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,10 +792,8 @@ def adjoint_jacobian_solve(self, adj_x, nl_deps, b):
x_0, _, adj_x_0, J = forward(y)
stop_manager()

assert len(manager()._cp._refs) == 3
assert tuple(manager()._cp._refs.keys()) == (var_id(y),
var_id(zero),
var_id(adj_x_0))
assert len(manager()._cp._refs) == 1
assert tuple(manager()._cp._refs.keys()) == (var_id(adj_x_0),)
assert len(manager()._cp._cp) == 0
if test_adj_ic:
assert len(manager()._cp._data) == 8
Expand All @@ -805,7 +803,7 @@ def adjoint_jacobian_solve(self, adj_x, nl_deps, b):
assert len(manager()._cp._data) == 9
assert tuple(map(len, manager()._cp._data.values())) \
== (0, 0, 0, 1, 0, 0, 2, 0, 0)
assert len(manager()._cp._storage) == 5
assert len(manager()._cp._storage) == 3

dJdx_0, dJdy = compute_gradient(J, [x_0, y])
assert var_linf_norm(dJdx_0) == 0.0
Expand Down
10 changes: 1 addition & 9 deletions tests/fenics/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,6 @@
import numpy as np
import pytest

try:
import hrevolve
except ImportError:
hrevolve = None

pytestmark = pytest.mark.skipif(
DEFAULT_COMM.size not in {1, 4},
reason="tests must be run in serial, or with 4 processes")
Expand Down Expand Up @@ -86,10 +81,7 @@ def diffusion_ref():
"snaps_in_ram": 2}),
("multistage", {"format": "hdf5", "snaps_on_disk": 1,
"snaps_in_ram": 2}),
pytest.param(
"H-Revolve", {"snapshots_on_disk": 1, "snapshots_in_ram": 2},
marks=pytest.mark.skipif(hrevolve is None,
reason="H-Revolve not available")),
("H-Revolve", {"snapshots_on_disk": 1, "snapshots_in_ram": 2}),
("mixed", {"snapshots": 2, "storage": "disk"})])
@seed_test
def test_oscillator(setup_test, test_leaks,
Expand Down
8 changes: 3 additions & 5 deletions tests/firedrake/test_equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,10 +1043,8 @@ def adjoint_jacobian_solve(self, adj_x, nl_deps, b):
x_0, _, adj_x_0, J = forward(y)
stop_manager()

assert len(manager()._cp._refs) == 3
assert tuple(manager()._cp._refs.keys()) == (var_id(y),
var_id(zero),
var_id(adj_x_0))
assert len(manager()._cp._refs) == 1
assert tuple(manager()._cp._refs.keys()) == (var_id(adj_x_0),)
assert len(manager()._cp._cp) == 0
if test_adj_ic:
assert len(manager()._cp._data) == 9
Expand All @@ -1056,7 +1054,7 @@ def adjoint_jacobian_solve(self, adj_x, nl_deps, b):
assert len(manager()._cp._data) == 10
assert tuple(map(len, manager()._cp._data.values())) \
== (0, 0, 0, 0, 1, 0, 0, 2, 0, 0)
assert len(manager()._cp._storage) == 5
assert len(manager()._cp._storage) == 3

dJdx_0, dJdy = compute_gradient(J, [x_0, y])
assert var_linf_norm(dJdx_0) == 0.0
Expand Down
10 changes: 1 addition & 9 deletions tests/firedrake/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,6 @@
import numpy as np
import pytest

try:
import hrevolve
except ImportError:
hrevolve = None

pytestmark = pytest.mark.skipif(
DEFAULT_COMM.size not in {1, 4},
reason="tests must be run in serial, or with 4 processes")
Expand Down Expand Up @@ -84,10 +79,7 @@ def diffusion_ref():
"snaps_in_ram": 2}),
("multistage", {"format": "hdf5", "snaps_on_disk": 1,
"snaps_in_ram": 2}),
pytest.param(
"H-Revolve", {"snapshots_on_disk": 1, "snapshots_in_ram": 2},
marks=pytest.mark.skipif(hrevolve is None,
reason="H-Revolve not available")),
("H-Revolve", {"snapshots_on_disk": 1, "snapshots_in_ram": 2}),
("mixed", {"snapshots": 2, "storage": "disk"})])
@seed_test
def test_oscillator(setup_test, test_leaks,
Expand Down
Loading

0 comments on commit d374b53

Please sign in to comment.