diff --git a/.github/workflows/test-base.yml b/.github/workflows/test-base.yml index 1fc39f584..268f8d44b 100644 --- a/.github/workflows/test-base.yml +++ b/.github/workflows/test-base.yml @@ -29,6 +29,7 @@ jobs: sudo apt-get update sudo apt-get install flake8 python3-h5py python3-numpy python3-pytest python3-pytest-timeout python3-pytest-xdist python3-scipy python3-sympy python3 -m pip install ruff + git clone --depth 1 https://github.com/firedrakeproject/checkpoint_schedules.git - name: Lint run: | cd tlm_adjoint @@ -36,5 +37,6 @@ jobs: ruff check - name: Run tests run: | + export PYTHONPATH=$PWD/checkpoint_schedules:$PYTHONPATH cd tlm_adjoint pytest-3 -v -n 2 --timeout=300 --timeout-method=thread diff --git a/.github/workflows/test-fenics.yml b/.github/workflows/test-fenics.yml index 905a8a34d..fad5e8e5d 100644 --- a/.github/workflows/test-fenics.yml +++ b/.github/workflows/test-fenics.yml @@ -29,6 +29,7 @@ jobs: sudo apt-get update sudo apt-get install flake8 python3-dolfin python3-h5py python3-numpy python3-pytest python3-pytest-timeout python3-pytest-xdist python3-scipy python3-sympy python3 -m pip install ruff + git clone --depth 1 https://github.com/firedrakeproject/checkpoint_schedules.git - name: Lint run: | cd tlm_adjoint @@ -36,5 +37,6 @@ jobs: ruff check - name: Run tests run: | + export PYTHONPATH=$PWD/checkpoint_schedules:$PYTHONPATH cd tlm_adjoint pytest-3 -v tests/base tests/checkpoint_schedules tests/fenics -n 2 --timeout=300 --timeout-method=thread diff --git a/.github/workflows/test-firedrake.yml b/.github/workflows/test-firedrake.yml index 6a248b5da..8a7a0fed0 100644 --- a/.github/workflows/test-firedrake.yml +++ b/.github/workflows/test-firedrake.yml @@ -31,6 +31,7 @@ jobs: run: | . /home/firedrake/firedrake/bin/activate python3 -m pip install jax jaxlib ruff pytest-timeout pytest-xdist + git clone --depth 1 https://github.com/firedrakeproject/checkpoint_schedules.git - name: Lint run: | . /home/firedrake/firedrake/bin/activate @@ -40,6 +41,7 @@ jobs: - name: Run tests run: | . /home/firedrake/firedrake/bin/activate + export PYTHONPATH=$PWD/checkpoint_schedules:$PYTHONPATH cd tlm_adjoint pytest -v tests/base tests/checkpoint_schedules tests/firedrake -n 2 --timeout=300 --timeout-method=thread test-complex: @@ -60,6 +62,7 @@ jobs: run: | . /home/firedrake/firedrake/bin/activate python3 -m pip install jax jaxlib ruff pytest-timeout pytest-xdist + git clone --depth 1 https://github.com/firedrakeproject/checkpoint_schedules.git - name: Lint run: | . /home/firedrake/firedrake/bin/activate @@ -69,5 +72,6 @@ jobs: - name: Run tests run: | . /home/firedrake/firedrake/bin/activate + export PYTHONPATH=$PWD/checkpoint_schedules:$PYTHONPATH cd tlm_adjoint pytest -v tests/base tests/checkpoint_schedules tests/firedrake -n 2 --timeout=300 --timeout-method=thread diff --git a/docs/source/dependencies.rst b/docs/source/dependencies.rst index d3129a0d8..a2d9f3fa9 100644 --- a/docs/source/dependencies.rst +++ b/docs/source/dependencies.rst @@ -8,6 +8,7 @@ tlm_adjoint requires: - `NumPy `_ - `SymPy `_ + - `checkpoint_schedules `_ Backend dependencies -------------------- diff --git a/tests/checkpoint_schedules/test_validity.py b/tests/checkpoint_schedules/test_validity.py index f4bda458f..2a718562b 100644 --- a/tests/checkpoint_schedules/test_validity.py +++ b/tests/checkpoint_schedules/test_validity.py @@ -11,10 +11,6 @@ import functools import pytest -try: - import hrevolve -except ImportError: - hrevolve = None try: import mpi4py.MPI as MPI except ImportError: @@ -72,10 +68,7 @@ def mixed(n, s): (two_level, {"period": 2}), (two_level, {"period": 7}), (two_level, {"period": 10}), - pytest.param( - h_revolve, {}, - marks=pytest.mark.skipif(hrevolve is None, - reason="H-Revolve not available")), + (h_revolve, {}), (mixed, {})]) @pytest.mark.parametrize("n, S", [(1, (0,)), (2, (1,)), diff --git a/tests/fenics/test_equations.py b/tests/fenics/test_equations.py index 7036ff0dc..74f47564a 100644 --- a/tests/fenics/test_equations.py +++ b/tests/fenics/test_equations.py @@ -792,10 +792,8 @@ def adjoint_jacobian_solve(self, adj_x, nl_deps, b): x_0, _, adj_x_0, J = forward(y) stop_manager() - assert len(manager()._cp._refs) == 3 - assert tuple(manager()._cp._refs.keys()) == (var_id(y), - var_id(zero), - var_id(adj_x_0)) + assert len(manager()._cp._refs) == 1 + assert tuple(manager()._cp._refs.keys()) == (var_id(adj_x_0),) assert len(manager()._cp._cp) == 0 if test_adj_ic: assert len(manager()._cp._data) == 8 @@ -805,7 +803,7 @@ def adjoint_jacobian_solve(self, adj_x, nl_deps, b): assert len(manager()._cp._data) == 9 assert tuple(map(len, manager()._cp._data.values())) \ == (0, 0, 0, 1, 0, 0, 2, 0, 0) - assert len(manager()._cp._storage) == 5 + assert len(manager()._cp._storage) == 3 dJdx_0, dJdy = compute_gradient(J, [x_0, y]) assert var_linf_norm(dJdx_0) == 0.0 diff --git a/tests/fenics/test_models.py b/tests/fenics/test_models.py index 4f0ae06cd..96b49ddb7 100644 --- a/tests/fenics/test_models.py +++ b/tests/fenics/test_models.py @@ -6,11 +6,6 @@ import numpy as np import pytest -try: - import hrevolve -except ImportError: - hrevolve = None - pytestmark = pytest.mark.skipif( DEFAULT_COMM.size not in {1, 4}, reason="tests must be run in serial, or with 4 processes") @@ -86,10 +81,7 @@ def diffusion_ref(): "snaps_in_ram": 2}), ("multistage", {"format": "hdf5", "snaps_on_disk": 1, "snaps_in_ram": 2}), - pytest.param( - "H-Revolve", {"snapshots_on_disk": 1, "snapshots_in_ram": 2}, - marks=pytest.mark.skipif(hrevolve is None, - reason="H-Revolve not available")), + ("H-Revolve", {"snapshots_on_disk": 1, "snapshots_in_ram": 2}), ("mixed", {"snapshots": 2, "storage": "disk"})]) @seed_test def test_oscillator(setup_test, test_leaks, diff --git a/tests/firedrake/test_equations.py b/tests/firedrake/test_equations.py index 88b146824..9484200fd 100644 --- a/tests/firedrake/test_equations.py +++ b/tests/firedrake/test_equations.py @@ -1043,10 +1043,8 @@ def adjoint_jacobian_solve(self, adj_x, nl_deps, b): x_0, _, adj_x_0, J = forward(y) stop_manager() - assert len(manager()._cp._refs) == 3 - assert tuple(manager()._cp._refs.keys()) == (var_id(y), - var_id(zero), - var_id(adj_x_0)) + assert len(manager()._cp._refs) == 1 + assert tuple(manager()._cp._refs.keys()) == (var_id(adj_x_0),) assert len(manager()._cp._cp) == 0 if test_adj_ic: assert len(manager()._cp._data) == 9 @@ -1056,7 +1054,7 @@ def adjoint_jacobian_solve(self, adj_x, nl_deps, b): assert len(manager()._cp._data) == 10 assert tuple(map(len, manager()._cp._data.values())) \ == (0, 0, 0, 0, 1, 0, 0, 2, 0, 0) - assert len(manager()._cp._storage) == 5 + assert len(manager()._cp._storage) == 3 dJdx_0, dJdy = compute_gradient(J, [x_0, y]) assert var_linf_norm(dJdx_0) == 0.0 diff --git a/tests/firedrake/test_models.py b/tests/firedrake/test_models.py index 0db70effa..b97743680 100644 --- a/tests/firedrake/test_models.py +++ b/tests/firedrake/test_models.py @@ -6,11 +6,6 @@ import numpy as np import pytest -try: - import hrevolve -except ImportError: - hrevolve = None - pytestmark = pytest.mark.skipif( DEFAULT_COMM.size not in {1, 4}, reason="tests must be run in serial, or with 4 processes") @@ -84,10 +79,7 @@ def diffusion_ref(): "snaps_in_ram": 2}), ("multistage", {"format": "hdf5", "snaps_on_disk": 1, "snaps_in_ram": 2}), - pytest.param( - "H-Revolve", {"snapshots_on_disk": 1, "snapshots_in_ram": 2}, - marks=pytest.mark.skipif(hrevolve is None, - reason="H-Revolve not available")), + ("H-Revolve", {"snapshots_on_disk": 1, "snapshots_in_ram": 2}), ("mixed", {"snapshots": 2, "storage": "disk"})]) @seed_test def test_oscillator(setup_test, test_leaks, diff --git a/tlm_adjoint/checkpoint_schedules/binomial.py b/tlm_adjoint/checkpoint_schedules/binomial.py index 9907a03b3..98a2db783 100644 --- a/tlm_adjoint/checkpoint_schedules/binomial.py +++ b/tlm_adjoint/checkpoint_schedules/binomial.py @@ -1,21 +1,11 @@ -from .schedule import ( - CheckpointSchedule, Clear, Configure, Forward, Reverse, Read, Write, - EndForward, EndReverse) +from checkpoint_schedules import StorageType +from checkpoint_schedules import ( + MultistageCheckpointSchedule as _MultistageCheckpointSchedule, + TwoLevelCheckpointSchedule as _TwoLevelCheckpointSchedule) -import functools -from operator import itemgetter - -try: - import numba - from numba import njit -except ImportError: - numba = None +from .translation import translation - def njit(fn): - @functools.wraps(fn) - def wrapped_fn(*args, **kwargs): - return fn(*args, **kwargs) - return wrapped_fn +import functools __all__ = \ [ @@ -24,69 +14,6 @@ def wrapped_fn(*args, **kwargs): ] -@njit -def n_advance(n, snapshots, *, trajectory="maximum"): - # GW2000 reference: - # Andreas Griewank and Andrea Walther, 'Algorithm 799: revolve: an - # implementation of checkpointing for the reverse or adjoint mode of - # computational differentiation', ACM Transactions on Mathematical - # Software, 26(1), pp. 19--45, 2000, doi: 10.1145/347837.347846 - - if n < 1: - raise ValueError("Require at least one block") - if snapshots <= 0: - raise ValueError("Require at least one snapshot") - - # Discard excess snapshots - snapshots = max(min(snapshots, n - 1), 1) - # Handle limiting cases - if snapshots == 1: - return n - 1 # Minimal storage - elif snapshots == n - 1: - return 1 # Maximal storage - - # Find t as in GW2000 Proposition 1 (note 'm' in GW2000 is 'n' here, and - # 's' in GW2000 is 'snapshots' here). Compute values of beta as in equation - # (1) of GW2000 as a side effect. We must have a minimal rerun of at least - # 2 (the minimal rerun of 1 case is maximal storage, handled above) so we - # start from t = 2. - t = 2 - b_s_tm2 = 1 - b_s_tm1 = snapshots + 1 - b_s_t = ((snapshots + 1) * (snapshots + 2)) // 2 - while b_s_tm1 >= n or n > b_s_t: - t += 1 - b_s_tm2 = b_s_tm1 - b_s_tm1 = b_s_t - b_s_t = (b_s_t * (snapshots + t)) // t - - if trajectory == "maximum": - # Return the maximal step size compatible with Fig. 4 of GW2000 - b_sm1_tm2 = (b_s_tm2 * snapshots) // (snapshots + t - 2) - if n <= b_s_tm1 + b_sm1_tm2: - return n - b_s_tm1 + b_s_tm2 - b_sm1_tm1 = (b_s_tm1 * snapshots) // (snapshots + t - 1) - b_sm2_tm1 = (b_sm1_tm1 * (snapshots - 1)) // (snapshots + t - 2) - if n <= b_s_tm1 + b_sm2_tm1 + b_sm1_tm2: - return b_s_tm2 + b_sm1_tm2 - elif n <= b_s_tm1 + b_sm1_tm1 + b_sm2_tm1: - return n - b_sm1_tm1 - b_sm2_tm1 - else: - return b_s_tm1 - elif trajectory == "revolve": - # GW2000, equation at the bottom of p. 34 - b_sm1_tm1 = (b_s_tm1 * snapshots) // (snapshots + t - 1) - b_sm2_tm1 = (b_sm1_tm1 * (snapshots - 1)) // (snapshots + t - 2) - if n <= b_s_tm1 + b_sm2_tm1: - return b_s_tm2 - elif n < b_s_tm1 + b_sm1_tm1 + b_sm2_tm1: - return n - b_sm1_tm1 - b_sm2_tm1 - else: - return b_s_tm1 - else: - raise ValueError("Unexpected trajectory: '{trajectory:s}'") - - def cache_step(fn): _cache = {} @@ -134,80 +61,7 @@ def optimal_steps(n, s): return n + optimal_extra_steps(n, s) -def allocate_snapshots(max_n, snapshots_in_ram, snapshots_on_disk, *, - write_weight=1.0, read_weight=1.0, delete_weight=0.0, - trajectory="maximum"): - snapshots_in_ram = min(snapshots_in_ram, max_n - 1) - snapshots_on_disk = min(snapshots_on_disk, max_n - 1) - snapshots = min(snapshots_in_ram + snapshots_on_disk, max_n - 1) - weights = [0.0 for _ in range(snapshots)] - - cp_schedule = MultistageCheckpointSchedule(max_n, snapshots, 0, - trajectory=trajectory) - - snapshot_i = -1 - - @functools.singledispatch - def action(cp_action): - raise TypeError(f"Unexpected checkpointing action: {cp_action}") - - @action.register(Read) - def action_read(cp_action): - nonlocal snapshot_i - - if snapshot_i < 0: - raise RuntimeError("Invalid checkpointing state") - weights[snapshot_i] += read_weight - if cp_action.delete: - weights[snapshot_i] += delete_weight - snapshot_i -= 1 - - @action.register(Write) - def action_write(cp_action): - nonlocal snapshot_i - - snapshot_i += 1 - if snapshot_i >= snapshots: - raise RuntimeError("Invalid checkpointing state") - weights[snapshot_i] += write_weight - - @action.register(Clear) - @action.register(Configure) - @action.register(Forward) - @action.register(Reverse) - @action.register(EndForward) - @action.register(EndReverse) - def action_pass(cp_action): - pass - - # Run the schedule, keeping track of the total read/write/delete costs - # associated with each storage location on the stack of checkpointing units - - while True: - cp_action = next(cp_schedule) - action(cp_action) - if isinstance(cp_action, EndReverse): - break - - assert snapshot_i == -1 - - # Allocate the checkpointing units with highest cost to RAM, and the - # remaining units to disk. For read and write costs of one and zero delete - # costs the distribution of storage between RAM and disk is then equivalent - # to that in - # Philipp Stumm and Andrea Walther, 'MultiStage approaches for optimal - # offline checkpointing', SIAM Journal on Scientific Computing, 31(3), - # pp. 1946--1967, 2009, doi: 10.1137/080718036 - - allocation = ["disk" for _ in range(snapshots)] - for i, _ in sorted(enumerate(weights), key=itemgetter(1), - reverse=True)[:snapshots_in_ram]: - allocation[i] = "RAM" - - return tuple(weights), tuple(allocation) - - -class MultistageCheckpointSchedule(CheckpointSchedule): +class MultistageCheckpointSchedule(translation(_MultistageCheckpointSchedule)): """A binomial checkpointing schedule using the approach described in - Andreas Griewank and Andrea Walther, 'Algorithm 799: revolve: an @@ -251,149 +105,8 @@ class MultistageCheckpointSchedule(CheckpointSchedule): function in dolfin-adjoint (see e.g. version 2017.1.0). """ - def __init__(self, max_n, snapshots_in_ram, snapshots_on_disk, *, - trajectory="maximum"): - snapshots_in_ram = min(snapshots_in_ram, max_n - 1) - snapshots_on_disk = min(snapshots_on_disk, max_n - 1) - if snapshots_in_ram == 0: - storage = tuple("disk" for _ in range(snapshots_on_disk)) - elif snapshots_on_disk == 0: - storage = tuple("RAM" for _ in range(snapshots_in_ram)) - else: - _, storage = allocate_snapshots( - max_n, snapshots_in_ram, snapshots_on_disk, - trajectory=trajectory) - - snapshots_in_ram = storage.count("RAM") - snapshots_on_disk = storage.count("disk") - - super().__init__(max_n=max_n) - self._snapshots_in_ram = snapshots_in_ram - self._snapshots_on_disk = snapshots_on_disk - self._storage = storage - self._exhausted = False - self._trajectory = trajectory - - def iter(self): - snapshots = [] - - def write(n): - if len(snapshots) >= self._snapshots_in_ram + self._snapshots_on_disk: # noqa: E501 - raise RuntimeError("Invalid checkpointing state") - snapshots.append(n) - return self._storage[len(snapshots) - 1] - - # Forward - - if self._max_n is None: - raise RuntimeError("Invalid checkpointing state") - while self._n < self._max_n - 1: - yield Configure(True, False) - - n_snapshots = (self._snapshots_in_ram - + self._snapshots_on_disk - - len(snapshots)) - n0 = self._n - n1 = n0 + n_advance(self._max_n - n0, n_snapshots, - trajectory=self._trajectory) - assert n1 > n0 - self._n = n1 - yield Forward(n0, n1) - - cp_storage = write(n0) - yield Write(n0, cp_storage) - yield Clear(True, True) - if self._n != self._max_n - 1: - raise RuntimeError("Invalid checkpointing state") - - # Forward -> reverse - - yield Configure(False, True) - - self._n += 1 - yield Forward(self._n - 1, self._n) - - yield EndForward() - self._r += 1 - yield Reverse(self._n, self._n - 1) - yield Clear(True, True) - - # Reverse - - while self._r < self._max_n: - if len(snapshots) == 0: - raise RuntimeError("Invalid checkpointing state") - cp_n = snapshots[-1] - cp_storage = self._storage[len(snapshots) - 1] - if cp_n == self._max_n - self._r - 1: - snapshots.pop() - self._n = cp_n - yield Read(cp_n, cp_storage, True) - yield Clear(True, True) - else: - self._n = cp_n - yield Read(cp_n, cp_storage, False) - yield Clear(True, True) - - yield Configure(False, False) - - n_snapshots = (self._snapshots_in_ram - + self._snapshots_on_disk - - len(snapshots) + 1) - n0 = self._n - n1 = n0 + n_advance(self._max_n - self._r - n0, n_snapshots, - trajectory=self._trajectory) - assert n1 > n0 - self._n = n1 - yield Forward(n0, n1) - yield Clear(True, True) - - while self._n < self._max_n - self._r - 1: - yield Configure(True, False) - - n_snapshots = (self._snapshots_in_ram - + self._snapshots_on_disk - - len(snapshots)) - n0 = self._n - n1 = n0 + n_advance(self._max_n - self._r - n0, n_snapshots, # noqa: E501 - trajectory=self._trajectory) - assert n1 > n0 - self._n = n1 - yield Forward(n0, n1) - - cp_storage = write(n0) - yield Write(n0, cp_storage) - yield Clear(True, True) - if self._n != self._max_n - self._r - 1: - raise RuntimeError("Invalid checkpointing state") - - yield Configure(False, True) - - self._n += 1 - yield Forward(self._n - 1, self._n) - - self._r += 1 - yield Reverse(self._n, self._n - 1) - yield Clear(True, True) - if self._r != self._max_n: - raise RuntimeError("Invalid checkpointing state") - if len(snapshots) != 0: - raise RuntimeError("Invalid checkpointing state") - - self._exhausted = True - yield EndReverse(True) - - @property - def is_exhausted(self): - return self._exhausted - - @property - def uses_disk_storage(self): - return self._snapshots_on_disk > 0 - - -class TwoLevelCheckpointSchedule(CheckpointSchedule): +class TwoLevelCheckpointSchedule(translation(_TwoLevelCheckpointSchedule)): """A two-level mixed periodic/binomial checkpointing schedule using the approach described in @@ -426,129 +139,8 @@ class TwoLevelCheckpointSchedule(CheckpointSchedule): def __init__(self, period, binomial_snapshots, *, binomial_storage="disk", binomial_trajectory="maximum"): - if period < 1: - raise ValueError("period must be positive") - if binomial_storage not in {"RAM", "disk"}: - raise ValueError("Invalid storage") - - super().__init__() - - self._period = period - self._binomial_snapshots = binomial_snapshots - self._binomial_storage = binomial_storage - self._trajectory = binomial_trajectory - - def iter(self): - # Forward - - while self._max_n is None: - yield Configure(True, False) - if self._max_n is not None: - # Unexpected finalize - raise RuntimeError("Invalid checkpointing state") - n0 = self._n - n1 = n0 + self._period - self._n = n1 - yield Forward(n0, n1) - - # Finalize permitted here - - yield Write(n0, "disk") - yield Clear(True, True) - - yield EndForward() - - while True: - # Reverse - - while self._r < self._max_n: - n = self._max_n - self._r - 1 - n0s = (n // self._period) * self._period - n1s = min(n0s + self._period, self._max_n) - if self._r != self._max_n - n1s: - raise RuntimeError("Invalid checkpointing state") - del n, n1s - - snapshots = [n0s] - while self._r < self._max_n - n0s: - if len(snapshots) == 0: - raise RuntimeError("Invalid checkpointing state") - cp_n = snapshots[-1] - if cp_n == self._max_n - self._r - 1: - snapshots.pop() - self._n = cp_n - if cp_n == n0s: - yield Read(cp_n, "disk", False) - else: - yield Read(cp_n, self._binomial_storage, True) - yield Clear(True, True) - else: - self._n = cp_n - if cp_n == n0s: - yield Read(cp_n, "disk", False) - else: - yield Read(cp_n, self._binomial_storage, False) - yield Clear(True, True) - - yield Configure(False, False) - - n_snapshots = (self._binomial_snapshots + 1 - - len(snapshots) + 1) - n0 = self._n - n1 = n0 + n_advance(self._max_n - self._r - n0, - n_snapshots, - trajectory=self._trajectory) - assert n1 > n0 - self._n = n1 - yield Forward(n0, n1) - yield Clear(True, True) - - while self._n < self._max_n - self._r - 1: - yield Configure(True, False) - - n_snapshots = (self._binomial_snapshots + 1 - - len(snapshots)) - n0 = self._n - n1 = n0 + n_advance(self._max_n - self._r - n0, - n_snapshots, - trajectory=self._trajectory) - assert n1 > n0 - self._n = n1 - yield Forward(n0, n1) - - if len(snapshots) >= self._binomial_snapshots + 1: - raise RuntimeError("Invalid checkpointing " - "state") - snapshots.append(n0) - yield Write(n0, self._binomial_storage) - yield Clear(True, True) - if self._n != self._max_n - self._r - 1: - raise RuntimeError("Invalid checkpointing state") - - yield Configure(False, True) - - self._n += 1 - yield Forward(self._n - 1, self._n) - - self._r += 1 - yield Reverse(self._n, self._n - 1) - yield Clear(True, True) - if self._r != self._max_n - n0s: - raise RuntimeError("Invalid checkpointing state") - if len(snapshots) != 0: - raise RuntimeError("Invalid checkpointing state") - if self._r != self._max_n: - raise RuntimeError("Invalid checkpointing state") - - # Reset for new reverse - - self._r = 0 - yield EndReverse(False) - - @property - def is_exhausted(self): - return False - - @property - def uses_disk_storage(self): - return True + super().__init__( + period, binomial_snapshots, + binomial_storage={"RAM": StorageType.RAM, + "disk": StorageType.DISK}[binomial_storage], + binomial_trajectory=binomial_trajectory) diff --git a/tlm_adjoint/checkpoint_schedules/h_revolve.py b/tlm_adjoint/checkpoint_schedules/h_revolve.py index f22a5b636..8b0e081ec 100644 --- a/tlm_adjoint/checkpoint_schedules/h_revolve.py +++ b/tlm_adjoint/checkpoint_schedules/h_revolve.py @@ -1,8 +1,6 @@ -from .schedule import ( - CheckpointSchedule, Clear, Configure, Forward, Reverse, Read, Write, - EndForward, EndReverse) +from checkpoint_schedules import HRevolve -import logging +from .translation import translation __all__ = \ [ @@ -10,7 +8,7 @@ ] -class HRevolveCheckpointSchedule(CheckpointSchedule): +class HRevolveCheckpointSchedule(translation(HRevolve)): """An H-Revolve checkpointing schedule. Converts from schedules as generated by the H-Revolve library, for the @@ -23,167 +21,21 @@ class HRevolveCheckpointSchedule(CheckpointSchedule): to store in memory. :arg snapshots_on_disk: The maximum number of forward restart checkpoints to store on disk. - :arg wvect: A two element :class:`tuple` defining the write cost associated - with saving a forward restart checkpoint to RAM (first element) and - disk (second element). - :arg rvect: A two element :class:`tuple` defining the read cost associated - with loading a forward restart checkpoint from RAM (first element) and - disk (second element). + :arg wd: The write cost associated with saving a forward restart checkpoint + to disk. + :arg rd: The read cost associated with loading a forward restart checkpoint + from disk. :arg uf: The cost of advancing the forward one step. :arg bf: The cost of advancing the forward one step, storing non-linear dependency data, and then advancing the adjoint over that step. - Remaining keyword arguments are passed to `hrevolve.hrevolve`. - The argument names `snaps_in_ram` and `snaps_on_disk` originate from the corresponding arguments for the `dolfin_adjoint.solving.adj_checkpointing` function in dolfin-adjoint (see e.g. version 2017.1.0). """ def __init__(self, max_n, snapshots_in_ram, snapshots_on_disk, *, - wvect=(0.0, 0.1), rvect=(0.0, 0.1), uf=1.0, ub=2.0, **kwargs): - super().__init__(max_n) - self._snapshots_in_ram = snapshots_in_ram - self._snapshots_on_disk = snapshots_on_disk - self._exhausted = False - - cvect = (snapshots_in_ram, snapshots_on_disk) - import hrevolve - schedule = hrevolve.hrevolve(max_n - 1, cvect, wvect, rvect, - uf=uf, ub=ub, **kwargs) - self._schedule = list(schedule) - - logger = logging.getLogger("tlm_adjoint.checkpointing") - logger.debug(f"H-Revolve schedule: {str(self._schedule):s}") - - def iter(self): - def action(i): - assert i >= 0 and i < len(self._schedule) - action = self._schedule[i] - cp_action = action.type - if cp_action == "Forward": - n_0 = action.index - n_1 = n_0 + 1 - storage = None - elif cp_action == "Forwards": - cp_action = "Forward" - n_0, n_1 = action.index - if n_1 <= n_0: - raise RuntimeError("Invalid schedule") - n_1 += 1 - storage = None - elif cp_action == "Backward": - n_0 = action.index - n_1 = None - storage = None - elif cp_action in {"Read", "Write", "Discard"}: - storage, n_0 = action.index - n_1 = None - storage = {0: "RAM", 1: "disk"}[storage] - else: - raise RuntimeError(f"Unexpected action: {cp_action:s}") - return cp_action, (n_0, n_1, storage) - - if self._max_n is None: - raise RuntimeError("Invalid checkpointing state") - - snapshots = set() - deferred_cp = None - - def write_deferred_cp(): - nonlocal deferred_cp - - if deferred_cp is not None: - snapshots.add(deferred_cp[0]) - yield Write(*deferred_cp) - deferred_cp = None - - for i in range(len(self._schedule)): - cp_action, (n_0, n_1, storage) = action(i) - - if cp_action == "Forward": - if n_0 != self._n: - raise RuntimeError("Invalid checkpointing state") - - yield Clear(True, True) - yield Configure(n_0 not in snapshots, False) - self._n = n_1 - yield Forward(n_0, n_1) - elif cp_action == "Backward": - if n_0 != self._n: - raise RuntimeError("Invalid checkpointing state") - if n_0 != self._max_n - self._r - 1: - raise RuntimeError("Invalid checkpointing state") - - yield from write_deferred_cp() - - yield Clear(True, True) - yield Configure(False, True) - self._n = n_0 + 1 - yield Forward(n_0, n_0 + 1) - if self._n == self._max_n: - if self._r != 0: - raise RuntimeError("Invalid checkpointing state") - yield EndForward() - self._r += 1 - yield Reverse(n_0 + 1, n_0) - elif cp_action == "Read": - if deferred_cp is not None: - raise RuntimeError("Invalid checkpointing state") - - if n_0 == self._max_n - self._r - 1: - cp_delete = True - elif i < len(self._schedule) - 2: - d_cp_action, (d_n_0, _, d_storage) = action(i + 2) - if d_cp_action == "Discard": - if d_n_0 != n_0 or d_storage != storage: - raise RuntimeError("Invalid schedule") - cp_delete = True - else: - cp_delete = False - - yield Clear(True, True) - if cp_delete: - snapshots.remove(n_0) - self._n = n_0 - yield Read(n_0, storage, cp_delete) - elif cp_action == "Write": - if n_0 != self._n: - raise RuntimeError("Invalid checkpointing state") - - yield from write_deferred_cp() - - deferred_cp = (n_0, storage) - - if i > 0: - r_cp_action, (r_n_0, _, _) = action(i - 1) - if r_cp_action == "Read": - if r_n_0 != n_0: - raise RuntimeError("Invalid schedule") - yield from write_deferred_cp() - elif cp_action == "Discard": - if i < 2: - raise RuntimeError("Invalid schedule") - r_cp_action, (r_n_0, _, r_storage) = action(i - 2) - if r_cp_action != "Read" \ - or r_n_0 != n_0 \ - or r_storage != storage: - raise RuntimeError("Invalid schedule") - else: - raise RuntimeError(f"Unexpected action: {cp_action:s}") - - if len(snapshots) != 0: - raise RuntimeError("Invalid checkpointing state") - - yield Clear(True, True) - - self._exhausted = True - yield EndReverse(True) - - @property - def is_exhausted(self): - return self._exhausted - - @property - def uses_disk_storage(self): - return self._snapshots_on_disk > 0 + wd=0.1, rd=0.1, uf=1.0, ub=2.0): + super().__init__( + max_n, snapshots_in_ram, snapshots_on_disk, + uf=uf, ub=ub, wd=wd, rd=rd) diff --git a/tlm_adjoint/checkpoint_schedules/memory.py b/tlm_adjoint/checkpoint_schedules/memory.py index 2d47ad7c7..affbe87cf 100644 --- a/tlm_adjoint/checkpoint_schedules/memory.py +++ b/tlm_adjoint/checkpoint_schedules/memory.py @@ -1,7 +1,7 @@ -from .schedule import ( - CheckpointSchedule, Configure, Forward, Reverse, EndForward, EndReverse) +from checkpoint_schedules import ( + SingleMemoryStorageSchedule as _SingleMemoryStorageSchedule) -import sys +from .translation import translation __all__ = \ [ @@ -9,47 +9,9 @@ ] -class MemoryCheckpointSchedule(CheckpointSchedule): - """A checkpointing schedule where all forward restart and non-linear - dependency data are stored in memory. +class MemoryCheckpointSchedule(translation(_SingleMemoryStorageSchedule)): + """A checkpointing schedule where all non-linear dependency data are stored + in memory. Online, unlimited adjoint calculations permitted. """ - - def iter(self): - # Forward - - if self._max_n is not None: - # Unexpected finalize - raise RuntimeError("Invalid checkpointing state") - yield Configure(True, True) - - while self._max_n is None: - n0 = self._n - n1 = n0 + sys.maxsize - self._n = n1 - yield Forward(n0, n1) - - yield EndForward() - - while True: - if self._r == 0: - # Reverse - - self._r = self._max_n - yield Reverse(self._max_n, 0) - elif self._r == self._max_n: - # Reset for new reverse - - self._r = 0 - yield EndReverse(False) - else: - raise RuntimeError("Invalid checkpointing state") - - @property - def is_exhausted(self): - return False - - @property - def uses_disk_storage(self): - return False diff --git a/tlm_adjoint/checkpoint_schedules/mixed.py b/tlm_adjoint/checkpoint_schedules/mixed.py index aeffc48a2..2d52b99c5 100644 --- a/tlm_adjoint/checkpoint_schedules/mixed.py +++ b/tlm_adjoint/checkpoint_schedules/mixed.py @@ -1,24 +1,11 @@ -from .schedule import ( - CheckpointSchedule, Clear, Configure, Forward, Reverse, Read, Write, - EndForward, EndReverse) +from checkpoint_schedules import StorageType +from checkpoint_schedules import ( + MixedCheckpointSchedule as _MixedCheckpointSchedule) + +from .translation import translation import enum import functools -import numpy as np -import warnings - -try: - import numba - from numba import njit -except ImportError: - numba = None - - def njit(fn): - @functools.wraps(fn) - def wrapped_fn(*args, **kwargs): - return fn(*args, **kwargs) - return wrapped_fn - __all__ = \ [ @@ -102,112 +89,7 @@ def mixed_step_memoization(n, s): return m -_NONE = int(StepType.NONE) -_FORWARD = int(StepType.FORWARD) -_FORWARD_REVERSE = int(StepType.FORWARD_REVERSE) -_WRITE_DATA = int(StepType.WRITE_DATA) -_WRITE_ICS = int(StepType.WRITE_ICS) - - -@njit -def mixed_steps_tabulation(n, s): - schedule = np.zeros((n + 1, s + 1, 3), dtype=np.int_) - schedule[:, :, 0] = _NONE - schedule[:, :, 1] = 0 - schedule[:, :, 2] = -1 - - for s_i in range(s + 1): - schedule[1, s_i, :] = (_FORWARD_REVERSE, 1, 1) - for s_i in range(1, s + 1): - for n_i in range(2, n + 1): - if n_i <= s_i + 1: - schedule[n_i, s_i, :] = (_WRITE_DATA, 1, n_i) - elif s_i == 1: - schedule[n_i, s_i, :] = (_WRITE_ICS, n_i - 1, n_i * (n_i + 1) // 2 - 1) # noqa: E501 - else: - for i in range(2, n_i): - assert schedule[i, s_i, 2] > 0 - assert schedule[n_i - i, s_i - 1, 2] > 0 - m1 = ( - i - + schedule[i, s_i, 2] - + schedule[n_i - i, s_i - 1, 2]) - if schedule[n_i, s_i, 2] < 0 or m1 <= schedule[n_i, s_i, 2]: # noqa: E501 - schedule[n_i, s_i, :] = (_WRITE_ICS, i, m1) - if schedule[n_i, s_i, 2] < 0: - raise RuntimeError("Failed to determine total number of " - "steps") - assert schedule[n_i - 1, s_i - 1, 2] > 0 - m1 = 1 + schedule[n_i - 1, s_i - 1, 2] - if m1 <= schedule[n_i, s_i, 2]: - schedule[n_i, s_i, :] = (_WRITE_DATA, 1, m1) - return schedule - - -def cache_step_0(fn): - _cache = {} - - @functools.wraps(fn) - def wrapped_fn(n, s): - # Avoid some cache misses - s = min(s, n - 2) - if (n, s) not in _cache: - _cache[(n, s)] = fn(n, s) - return _cache[(n, s)] - - return wrapped_fn - - -@cache_step_0 -def mixed_step_memoization_0(n, s): - if s < 0: - raise ValueError("Invalid number of snapshots") - if n < s + 2: - raise ValueError("Invalid number of steps") - - if s == 0: - return (StepType.FORWARD_REVERSE, n, n * (n + 1) // 2 - 1) - else: - m = None - for i in range(1, n): - m1 = ( - i - + mixed_step_memoization(i, s + 1)[2] - + mixed_step_memoization(n - i, s)[2]) - if m is None or m1 <= m[2]: - m = (StepType.FORWARD, i, m1) - if m is None: - raise RuntimeError("Failed to determine total number of steps") - return m - - -@njit -def mixed_steps_tabulation_0(n, s, schedule): - schedule_0 = np.zeros((n + 1, s + 1, 3), dtype=np.int_) - schedule_0[:, :, 0] = _NONE - schedule_0[:, :, 1] = 0 - schedule_0[:, :, 2] = -1 - - for n_i in range(2, n + 1): - schedule_0[n_i, 0, :] = (_FORWARD_REVERSE, n_i, n_i * (n_i + 1) // 2 - 1) # noqa: E501 - for s_i in range(1, s): - for n_i in range(s_i + 2, n + 1): - for i in range(1, n_i): - assert schedule[i, s_i + 1, 2] > 0 - assert schedule[n_i - i, s_i, 2] > 0 - m1 = ( - i - + schedule[i, s_i + 1, 2] - + schedule[n_i - i, s_i, 2]) - if schedule_0[n_i, s_i, 2] < 0 or m1 <= schedule_0[n_i, s_i, 2]: # noqa: E501 - schedule_0[n_i, s_i, :] = (_FORWARD, i, m1) - if schedule_0[n_i, s_i, 2] < 0: - raise RuntimeError("Failed to determine total number of " - "steps") - return schedule_0 - - -class MixedCheckpointSchedule(CheckpointSchedule): +class MixedCheckpointSchedule(translation(_MixedCheckpointSchedule)): """A checkpointing schedule which mixes storage of forward restart data and non-linear dependency data in checkpointing units. Assumes that the data required to restart the forward has the same size as the data required to @@ -228,154 +110,7 @@ class MixedCheckpointSchedule(CheckpointSchedule): """ def __init__(self, max_n, snapshots, *, storage="disk"): - if snapshots < min(1, max_n - 1): - raise ValueError("Invalid number of snapshots") - if storage not in {"RAM", "disk"}: - raise ValueError("Invalid storage") - - super().__init__(max_n) - self._exhausted = False - self._snapshots = min(snapshots, max_n - 1) - self._storage = storage - - def iter(self): - snapshot_n = set() - snapshots = [] - - if self._max_n is None: - raise RuntimeError("Invalid checkpointing state") - - if numba is None: - warnings.warn("Numba not available -- using memoization", - RuntimeWarning) - else: - schedule = mixed_steps_tabulation(self._max_n, self._snapshots) - schedule_0 = mixed_steps_tabulation_0(self._max_n, self._snapshots, schedule) # noqa: E501 - - step_type = StepType.NONE - while True: - while self._n < self._max_n - self._r: - n0 = self._n - if n0 in snapshot_n: - # n0 checkpoint exists - if numba is None: - step_type, n1, _ = mixed_step_memoization_0( - self._max_n - self._r - n0, - self._snapshots - len(snapshots)) - else: - step_type, n1, _ = schedule_0[ - self._max_n - self._r - n0, - self._snapshots - len(snapshots)] - else: - # n0 checkpoint does not exist - if numba is None: - step_type, n1, _ = mixed_step_memoization( - self._max_n - self._r - n0, - self._snapshots - len(snapshots)) - else: - step_type, n1, _ = schedule[ - self._max_n - self._r - n0, - self._snapshots - len(snapshots)] - n1 += n0 - - if step_type == StepType.FORWARD_REVERSE: - if n1 > n0 + 1: - yield Configure(False, False) - self._n = n1 - 1 - yield Forward(n0, n1 - 1) - yield Clear(True, True) - elif n1 <= n0: - raise RuntimeError("Invalid step") - yield Configure(False, True) - self._n += 1 - yield Forward(n1 - 1, n1) - elif step_type == StepType.FORWARD: - if n1 <= n0: - raise RuntimeError("Invalid step") - yield Configure(False, False) - self._n = n1 - yield Forward(n0, n1) - yield Clear(True, True) - elif step_type == StepType.WRITE_DATA: - if n1 != n0 + 1: - raise RuntimeError("Invalid step") - yield Configure(False, True) - self._n = n1 - yield Forward(n0, n1) - if n0 in snapshot_n: - raise RuntimeError("Invalid checkpointing state") - elif len(snapshots) > self._snapshots - 1: - raise RuntimeError("Invalid checkpointing state") - snapshot_n.add(n0) - snapshots.append((StepType.READ_DATA, n0)) - yield Write(n0, self._storage) - yield Clear(True, True) - elif step_type == StepType.WRITE_ICS: - if n1 <= n0 + 1: - raise ValueError("Invalid step") - yield Configure(True, False) - self._n = n1 - yield Forward(n0, n1) - if n0 in snapshot_n: - raise RuntimeError("Invalid checkpointing state") - elif len(snapshots) > self._snapshots - 1: - raise RuntimeError("Invalid checkpointing state") - snapshot_n.add(n0) - snapshots.append((StepType.READ_ICS, n0)) - yield Write(n0, self._storage) - yield Clear(True, True) - else: - raise RuntimeError("Unexpected step type") - if self._n != self._max_n - self._r: - raise RuntimeError("Invalid checkpointing state") - if step_type not in (StepType.FORWARD_REVERSE, StepType.READ_DATA): - raise RuntimeError("Invalid checkpointing state") - - if self._r == 0: - yield EndForward() - - self._r += 1 - yield Reverse(self._max_n - self._r + 1, self._max_n - self._r) - yield Clear(True, True) - - if self._r == self._max_n: - break - - step_type, cp_n = snapshots[-1] - - # Delete if we have (possibly after deleting this checkpoint) - # enough storage left to store all non-linear dependency data - cp_delete = (cp_n >= (self._max_n - self._r - 1 - - (self._snapshots - len(snapshots) + 1))) - if cp_delete: - snapshot_n.remove(cp_n) - snapshots.pop() - - self._n = cp_n - if step_type == StepType.READ_DATA: - # Non-linear dependency data checkpoint - if not cp_delete: - # We cannot advance from a loaded non-linear dependency - # checkpoint, and so we expect to use it immediately - raise RuntimeError("Invalid checkpointing state") - # Note that we cannot in general restart the forward here - self._n += 1 - elif step_type != StepType.READ_ICS: - raise RuntimeError("Invalid checkpointing state") - yield Read(cp_n, self._storage, cp_delete) - if step_type == StepType.READ_ICS: - yield Clear(True, True) - - if len(snapshot_n) > 0 or len(snapshots) > 0: - raise RuntimeError("Invalid checkpointing state") - - self._exhausted = True - yield EndReverse(True) - - @property - def is_exhausted(self): - return self._exhausted - - @property - def uses_disk_storage(self): - return self._max_n > 1 and self._storage == "disk" + super().__init__( + max_n, snapshots, + storage={"RAM": StorageType.RAM, + "disk": StorageType.DISK}[storage]) diff --git a/tlm_adjoint/checkpoint_schedules/none.py b/tlm_adjoint/checkpoint_schedules/none.py index 64287b4a6..bc225bf8b 100644 --- a/tlm_adjoint/checkpoint_schedules/none.py +++ b/tlm_adjoint/checkpoint_schedules/none.py @@ -1,6 +1,7 @@ -from .schedule import CheckpointSchedule, Configure, Forward, EndForward +from checkpoint_schedules import ( + NoneCheckpointSchedule as _NoneCheckpointSchedule) -import sys +from .translation import translation __all__ = \ [ @@ -8,38 +9,9 @@ ] -class NoneCheckpointSchedule(CheckpointSchedule): +class NoneCheckpointSchedule(translation(_NoneCheckpointSchedule)): """A checkpointing schedule for the case where no adjoint calculation is performed. Online, zero adjoint calculations permitted. """ - - def __init__(self): - super().__init__() - self._exhausted = False - - def iter(self): - # Forward - - if self._max_n is not None: - # Unexpected finalize - raise RuntimeError("Invalid checkpointing state") - yield Configure(False, False) - - while self._max_n is None: - n0 = self._n - n1 = n0 + sys.maxsize - self._n = n1 - yield Forward(n0, n1) - - self._exhausted = True - yield EndForward() - - @property - def is_exhausted(self): - return self._exhausted - - @property - def uses_disk_storage(self): - return False diff --git a/tlm_adjoint/checkpoint_schedules/translation.py b/tlm_adjoint/checkpoint_schedules/translation.py new file mode 100644 index 000000000..0c1999e66 --- /dev/null +++ b/tlm_adjoint/checkpoint_schedules/translation.py @@ -0,0 +1,148 @@ +from checkpoint_schedules import StorageType +from checkpoint_schedules import ( + Forward as _Forward, Reverse as _Reverse, Copy as _Copy, Move as _Move, + EndForward as _EndForward, EndReverse as _EndReverse) + +from .schedule import ( + CheckpointSchedule, Configure, Clear, Forward, Reverse, Read, Write, + EndForward, EndReverse) + +import functools + +__all__ = \ + [ + ] + + +def translation(cls): + class Translation(CheckpointSchedule): + def __init__(self, *args, **kwargs): + self._cp_schedule = cls(*args, **kwargs) + super().__init__(self._cp_schedule.max_n) + self._is_exhausted = self._cp_schedule.is_exhausted + + def iter(self): + # Used to ensure that we do not finalize the wrapped scheduler + # while yielding actions associated with a single wrapped action. + # Prevents multiple finalization of the wrapped schedule. + def locked(fn): + @functools.wraps(fn) + def wrapped_fn(cp_action): + max_n = self._max_n + try: + yield from fn(cp_action) + finally: + if self._max_n != max_n: + self._cp_schedule.finalize(self._max_n) + return wrapped_fn + + @functools.singledispatch + @locked + def action(cp_action): + raise TypeError(f"Unexpected action type: {type(cp_action)}") + yield None + + @action.register(_Forward) + @locked + def action_forward(cp_action): + yield Clear(True, True) + yield Configure(cp_action.write_ics, cp_action.write_adj_deps) + self._n = self._cp_schedule.n + yield Forward(cp_action.n0, cp_action.n1) + if cp_action.storage not in {StorageType.NONE, + StorageType.WORK}: + yield Write(cp_action.n0, + {StorageType.RAM: "RAM", + StorageType.DISK: "disk"}[cp_action.storage]) + yield Clear(True, False) + + @action.register(_Reverse) + @locked + def action_reverse(cp_action): + if self._max_n is None: + raise RuntimeError("Invalid checkpointing state") + self._r = self._cp_schedule.r + yield Reverse(cp_action.n1, cp_action.n0) + + @action.register(_Copy) + @locked + def action_copy(cp_action): + if cp_action.to_storage == StorageType.NONE: + pass + elif cp_action.to_storage in {StorageType.RAM, StorageType.DISK}: # noqa: E501 + yield Clear(True, True) + self._n = self._cp_schedule.n + yield Read(cp_action.n, + {StorageType.RAM: "RAM", + StorageType.DISK: "disk"}[cp_action.from_storage], # noqa: E501 + False) + yield Write(cp_action.n0, + {StorageType.RAM: "RAM", + StorageType.DISK: "disk"}[cp_action.to_storage]) # noqa: E501 + yield Clear(True, True) + elif cp_action.to_storage == StorageType.WORK: + yield Clear(True, True) + self._n = self._cp_schedule.n + yield Read(cp_action.n, + {StorageType.RAM: "RAM", + StorageType.DISK: "disk"}[cp_action.from_storage], # noqa: E501 + False) + else: + raise ValueError(f"Unexpected storage type: " + f"{cp_action.to_storage}") + + @action.register(_Move) + @locked + def action_move(cp_action): + if cp_action.to_storage == StorageType.NONE: + pass + elif cp_action.to_storage in {StorageType.RAM, StorageType.DISK}: # noqa: E501 + yield Clear(True, True) + self._n = self._cp_schedule.n + yield Read(cp_action.n, + {StorageType.RAM: "RAM", + StorageType.DISK: "disk"}[cp_action.from_storage], # noqa: E501 + True) + yield Write(cp_action.n0, + {StorageType.RAM: "RAM", + StorageType.DISK: "disk"}[cp_action.to_storage]) # noqa: E501 + yield Clear(True, True) + elif cp_action.to_storage == StorageType.WORK: + yield Clear(True, True) + self._n = self._cp_schedule.n + yield Read(cp_action.n, + {StorageType.RAM: "RAM", + StorageType.DISK: "disk"}[cp_action.from_storage], # noqa: E501 + True) + else: + raise ValueError(f"Unexpected storage type: " + f"{cp_action.to_storage}") + + @action.register(_EndForward) + @locked + def action_end_forward(cp_action): + self._is_exhausted = self._cp_schedule.is_exhausted + yield EndForward() + + @action.register(_EndReverse) + @locked + def action_end_reverse(cp_action): + if self._cp_schedule.is_exhausted: + yield Clear(True, True) + self._r = self._cp_schedule.r + self._is_exhausted = self._cp_schedule.is_exhausted + yield EndReverse(self._cp_schedule.is_exhausted) + + yield Clear(True, True) + while not self._cp_schedule.is_exhausted: + yield from action(next(self._cp_schedule)) + + @property + def is_exhausted(self): + return self._is_exhausted + + @property + def uses_disk_storage(self): + return self._cp_schedule.uses_storage_type(StorageType.DISK) + + return Translation