diff --git a/.github/workflows/test-base.yml b/.github/workflows/test-base.yml
index 1fc39f584..268f8d44b 100644
--- a/.github/workflows/test-base.yml
+++ b/.github/workflows/test-base.yml
@@ -29,6 +29,7 @@ jobs:
sudo apt-get update
sudo apt-get install flake8 python3-h5py python3-numpy python3-pytest python3-pytest-timeout python3-pytest-xdist python3-scipy python3-sympy
python3 -m pip install ruff
+ git clone --depth 1 https://github.com/firedrakeproject/checkpoint_schedules.git
- name: Lint
run: |
cd tlm_adjoint
@@ -36,5 +37,6 @@ jobs:
ruff check
- name: Run tests
run: |
+ export PYTHONPATH=$PWD/checkpoint_schedules:$PYTHONPATH
cd tlm_adjoint
pytest-3 -v -n 2 --timeout=300 --timeout-method=thread
diff --git a/.github/workflows/test-fenics.yml b/.github/workflows/test-fenics.yml
index 905a8a34d..fad5e8e5d 100644
--- a/.github/workflows/test-fenics.yml
+++ b/.github/workflows/test-fenics.yml
@@ -29,6 +29,7 @@ jobs:
sudo apt-get update
sudo apt-get install flake8 python3-dolfin python3-h5py python3-numpy python3-pytest python3-pytest-timeout python3-pytest-xdist python3-scipy python3-sympy
python3 -m pip install ruff
+ git clone --depth 1 https://github.com/firedrakeproject/checkpoint_schedules.git
- name: Lint
run: |
cd tlm_adjoint
@@ -36,5 +37,6 @@ jobs:
ruff check
- name: Run tests
run: |
+ export PYTHONPATH=$PWD/checkpoint_schedules:$PYTHONPATH
cd tlm_adjoint
pytest-3 -v tests/base tests/checkpoint_schedules tests/fenics -n 2 --timeout=300 --timeout-method=thread
diff --git a/.github/workflows/test-firedrake.yml b/.github/workflows/test-firedrake.yml
index 6a248b5da..8a7a0fed0 100644
--- a/.github/workflows/test-firedrake.yml
+++ b/.github/workflows/test-firedrake.yml
@@ -31,6 +31,7 @@ jobs:
run: |
. /home/firedrake/firedrake/bin/activate
python3 -m pip install jax jaxlib ruff pytest-timeout pytest-xdist
+ git clone --depth 1 https://github.com/firedrakeproject/checkpoint_schedules.git
- name: Lint
run: |
. /home/firedrake/firedrake/bin/activate
@@ -40,6 +41,7 @@ jobs:
- name: Run tests
run: |
. /home/firedrake/firedrake/bin/activate
+ export PYTHONPATH=$PWD/checkpoint_schedules:$PYTHONPATH
cd tlm_adjoint
pytest -v tests/base tests/checkpoint_schedules tests/firedrake -n 2 --timeout=300 --timeout-method=thread
test-complex:
@@ -60,6 +62,7 @@ jobs:
run: |
. /home/firedrake/firedrake/bin/activate
python3 -m pip install jax jaxlib ruff pytest-timeout pytest-xdist
+ git clone --depth 1 https://github.com/firedrakeproject/checkpoint_schedules.git
- name: Lint
run: |
. /home/firedrake/firedrake/bin/activate
@@ -69,5 +72,6 @@ jobs:
- name: Run tests
run: |
. /home/firedrake/firedrake/bin/activate
+ export PYTHONPATH=$PWD/checkpoint_schedules:$PYTHONPATH
cd tlm_adjoint
pytest -v tests/base tests/checkpoint_schedules tests/firedrake -n 2 --timeout=300 --timeout-method=thread
diff --git a/docs/source/dependencies.rst b/docs/source/dependencies.rst
index d3129a0d8..a2d9f3fa9 100644
--- a/docs/source/dependencies.rst
+++ b/docs/source/dependencies.rst
@@ -8,6 +8,7 @@ tlm_adjoint requires:
- `NumPy `_
- `SymPy `_
+ - `checkpoint_schedules `_
Backend dependencies
--------------------
diff --git a/tests/checkpoint_schedules/test_validity.py b/tests/checkpoint_schedules/test_validity.py
index f4bda458f..2a718562b 100644
--- a/tests/checkpoint_schedules/test_validity.py
+++ b/tests/checkpoint_schedules/test_validity.py
@@ -11,10 +11,6 @@
import functools
import pytest
-try:
- import hrevolve
-except ImportError:
- hrevolve = None
try:
import mpi4py.MPI as MPI
except ImportError:
@@ -72,10 +68,7 @@ def mixed(n, s):
(two_level, {"period": 2}),
(two_level, {"period": 7}),
(two_level, {"period": 10}),
- pytest.param(
- h_revolve, {},
- marks=pytest.mark.skipif(hrevolve is None,
- reason="H-Revolve not available")),
+ (h_revolve, {}),
(mixed, {})])
@pytest.mark.parametrize("n, S", [(1, (0,)),
(2, (1,)),
diff --git a/tests/fenics/test_equations.py b/tests/fenics/test_equations.py
index 85fff736d..9dfa5de67 100644
--- a/tests/fenics/test_equations.py
+++ b/tests/fenics/test_equations.py
@@ -864,10 +864,8 @@ def adjoint_jacobian_solve(self, adj_x, nl_deps, b):
x_0, _, adj_x_0, J = forward(y)
stop_manager()
- assert len(manager()._cp._refs) == 3
- assert tuple(manager()._cp._refs.keys()) == (var_id(y),
- var_id(zero),
- var_id(adj_x_0))
+ assert len(manager()._cp._refs) == 1
+ assert tuple(manager()._cp._refs.keys()) == (var_id(adj_x_0),)
assert len(manager()._cp._cp) == 0
if test_adj_ic:
assert len(manager()._cp._data) == 8
@@ -877,7 +875,7 @@ def adjoint_jacobian_solve(self, adj_x, nl_deps, b):
assert len(manager()._cp._data) == 9
assert tuple(map(len, manager()._cp._data.values())) \
== (0, 0, 0, 1, 0, 0, 2, 0, 0)
- assert len(manager()._cp._storage) == 5
+ assert len(manager()._cp._storage) == 3
dJdx_0, dJdy = compute_gradient(J, [x_0, y])
assert var_linf_norm(dJdx_0) == 0.0
diff --git a/tests/fenics/test_models.py b/tests/fenics/test_models.py
index 4f0ae06cd..96b49ddb7 100644
--- a/tests/fenics/test_models.py
+++ b/tests/fenics/test_models.py
@@ -6,11 +6,6 @@
import numpy as np
import pytest
-try:
- import hrevolve
-except ImportError:
- hrevolve = None
-
pytestmark = pytest.mark.skipif(
DEFAULT_COMM.size not in {1, 4},
reason="tests must be run in serial, or with 4 processes")
@@ -86,10 +81,7 @@ def diffusion_ref():
"snaps_in_ram": 2}),
("multistage", {"format": "hdf5", "snaps_on_disk": 1,
"snaps_in_ram": 2}),
- pytest.param(
- "H-Revolve", {"snapshots_on_disk": 1, "snapshots_in_ram": 2},
- marks=pytest.mark.skipif(hrevolve is None,
- reason="H-Revolve not available")),
+ ("H-Revolve", {"snapshots_on_disk": 1, "snapshots_in_ram": 2}),
("mixed", {"snapshots": 2, "storage": "disk"})])
@seed_test
def test_oscillator(setup_test, test_leaks,
diff --git a/tests/firedrake/test_equations.py b/tests/firedrake/test_equations.py
index 5d153a260..1251c7ce9 100644
--- a/tests/firedrake/test_equations.py
+++ b/tests/firedrake/test_equations.py
@@ -1115,10 +1115,8 @@ def adjoint_jacobian_solve(self, adj_x, nl_deps, b):
x_0, _, adj_x_0, J = forward(y)
stop_manager()
- assert len(manager()._cp._refs) == 3
- assert tuple(manager()._cp._refs.keys()) == (var_id(y),
- var_id(zero),
- var_id(adj_x_0))
+ assert len(manager()._cp._refs) == 1
+ assert tuple(manager()._cp._refs.keys()) == (var_id(adj_x_0),)
assert len(manager()._cp._cp) == 0
if test_adj_ic:
assert len(manager()._cp._data) == 9
@@ -1128,7 +1126,7 @@ def adjoint_jacobian_solve(self, adj_x, nl_deps, b):
assert len(manager()._cp._data) == 10
assert tuple(map(len, manager()._cp._data.values())) \
== (0, 0, 0, 0, 1, 0, 0, 2, 0, 0)
- assert len(manager()._cp._storage) == 5
+ assert len(manager()._cp._storage) == 3
dJdx_0, dJdy = compute_gradient(J, [x_0, y])
assert var_linf_norm(dJdx_0) == 0.0
diff --git a/tests/firedrake/test_models.py b/tests/firedrake/test_models.py
index 0db70effa..b97743680 100644
--- a/tests/firedrake/test_models.py
+++ b/tests/firedrake/test_models.py
@@ -6,11 +6,6 @@
import numpy as np
import pytest
-try:
- import hrevolve
-except ImportError:
- hrevolve = None
-
pytestmark = pytest.mark.skipif(
DEFAULT_COMM.size not in {1, 4},
reason="tests must be run in serial, or with 4 processes")
@@ -84,10 +79,7 @@ def diffusion_ref():
"snaps_in_ram": 2}),
("multistage", {"format": "hdf5", "snaps_on_disk": 1,
"snaps_in_ram": 2}),
- pytest.param(
- "H-Revolve", {"snapshots_on_disk": 1, "snapshots_in_ram": 2},
- marks=pytest.mark.skipif(hrevolve is None,
- reason="H-Revolve not available")),
+ ("H-Revolve", {"snapshots_on_disk": 1, "snapshots_in_ram": 2}),
("mixed", {"snapshots": 2, "storage": "disk"})])
@seed_test
def test_oscillator(setup_test, test_leaks,
diff --git a/tlm_adjoint/checkpoint_schedules/binomial.py b/tlm_adjoint/checkpoint_schedules/binomial.py
index 9907a03b3..98a2db783 100644
--- a/tlm_adjoint/checkpoint_schedules/binomial.py
+++ b/tlm_adjoint/checkpoint_schedules/binomial.py
@@ -1,21 +1,11 @@
-from .schedule import (
- CheckpointSchedule, Clear, Configure, Forward, Reverse, Read, Write,
- EndForward, EndReverse)
+from checkpoint_schedules import StorageType
+from checkpoint_schedules import (
+ MultistageCheckpointSchedule as _MultistageCheckpointSchedule,
+ TwoLevelCheckpointSchedule as _TwoLevelCheckpointSchedule)
-import functools
-from operator import itemgetter
-
-try:
- import numba
- from numba import njit
-except ImportError:
- numba = None
+from .translation import translation
- def njit(fn):
- @functools.wraps(fn)
- def wrapped_fn(*args, **kwargs):
- return fn(*args, **kwargs)
- return wrapped_fn
+import functools
__all__ = \
[
@@ -24,69 +14,6 @@ def wrapped_fn(*args, **kwargs):
]
-@njit
-def n_advance(n, snapshots, *, trajectory="maximum"):
- # GW2000 reference:
- # Andreas Griewank and Andrea Walther, 'Algorithm 799: revolve: an
- # implementation of checkpointing for the reverse or adjoint mode of
- # computational differentiation', ACM Transactions on Mathematical
- # Software, 26(1), pp. 19--45, 2000, doi: 10.1145/347837.347846
-
- if n < 1:
- raise ValueError("Require at least one block")
- if snapshots <= 0:
- raise ValueError("Require at least one snapshot")
-
- # Discard excess snapshots
- snapshots = max(min(snapshots, n - 1), 1)
- # Handle limiting cases
- if snapshots == 1:
- return n - 1 # Minimal storage
- elif snapshots == n - 1:
- return 1 # Maximal storage
-
- # Find t as in GW2000 Proposition 1 (note 'm' in GW2000 is 'n' here, and
- # 's' in GW2000 is 'snapshots' here). Compute values of beta as in equation
- # (1) of GW2000 as a side effect. We must have a minimal rerun of at least
- # 2 (the minimal rerun of 1 case is maximal storage, handled above) so we
- # start from t = 2.
- t = 2
- b_s_tm2 = 1
- b_s_tm1 = snapshots + 1
- b_s_t = ((snapshots + 1) * (snapshots + 2)) // 2
- while b_s_tm1 >= n or n > b_s_t:
- t += 1
- b_s_tm2 = b_s_tm1
- b_s_tm1 = b_s_t
- b_s_t = (b_s_t * (snapshots + t)) // t
-
- if trajectory == "maximum":
- # Return the maximal step size compatible with Fig. 4 of GW2000
- b_sm1_tm2 = (b_s_tm2 * snapshots) // (snapshots + t - 2)
- if n <= b_s_tm1 + b_sm1_tm2:
- return n - b_s_tm1 + b_s_tm2
- b_sm1_tm1 = (b_s_tm1 * snapshots) // (snapshots + t - 1)
- b_sm2_tm1 = (b_sm1_tm1 * (snapshots - 1)) // (snapshots + t - 2)
- if n <= b_s_tm1 + b_sm2_tm1 + b_sm1_tm2:
- return b_s_tm2 + b_sm1_tm2
- elif n <= b_s_tm1 + b_sm1_tm1 + b_sm2_tm1:
- return n - b_sm1_tm1 - b_sm2_tm1
- else:
- return b_s_tm1
- elif trajectory == "revolve":
- # GW2000, equation at the bottom of p. 34
- b_sm1_tm1 = (b_s_tm1 * snapshots) // (snapshots + t - 1)
- b_sm2_tm1 = (b_sm1_tm1 * (snapshots - 1)) // (snapshots + t - 2)
- if n <= b_s_tm1 + b_sm2_tm1:
- return b_s_tm2
- elif n < b_s_tm1 + b_sm1_tm1 + b_sm2_tm1:
- return n - b_sm1_tm1 - b_sm2_tm1
- else:
- return b_s_tm1
- else:
- raise ValueError("Unexpected trajectory: '{trajectory:s}'")
-
-
def cache_step(fn):
_cache = {}
@@ -134,80 +61,7 @@ def optimal_steps(n, s):
return n + optimal_extra_steps(n, s)
-def allocate_snapshots(max_n, snapshots_in_ram, snapshots_on_disk, *,
- write_weight=1.0, read_weight=1.0, delete_weight=0.0,
- trajectory="maximum"):
- snapshots_in_ram = min(snapshots_in_ram, max_n - 1)
- snapshots_on_disk = min(snapshots_on_disk, max_n - 1)
- snapshots = min(snapshots_in_ram + snapshots_on_disk, max_n - 1)
- weights = [0.0 for _ in range(snapshots)]
-
- cp_schedule = MultistageCheckpointSchedule(max_n, snapshots, 0,
- trajectory=trajectory)
-
- snapshot_i = -1
-
- @functools.singledispatch
- def action(cp_action):
- raise TypeError(f"Unexpected checkpointing action: {cp_action}")
-
- @action.register(Read)
- def action_read(cp_action):
- nonlocal snapshot_i
-
- if snapshot_i < 0:
- raise RuntimeError("Invalid checkpointing state")
- weights[snapshot_i] += read_weight
- if cp_action.delete:
- weights[snapshot_i] += delete_weight
- snapshot_i -= 1
-
- @action.register(Write)
- def action_write(cp_action):
- nonlocal snapshot_i
-
- snapshot_i += 1
- if snapshot_i >= snapshots:
- raise RuntimeError("Invalid checkpointing state")
- weights[snapshot_i] += write_weight
-
- @action.register(Clear)
- @action.register(Configure)
- @action.register(Forward)
- @action.register(Reverse)
- @action.register(EndForward)
- @action.register(EndReverse)
- def action_pass(cp_action):
- pass
-
- # Run the schedule, keeping track of the total read/write/delete costs
- # associated with each storage location on the stack of checkpointing units
-
- while True:
- cp_action = next(cp_schedule)
- action(cp_action)
- if isinstance(cp_action, EndReverse):
- break
-
- assert snapshot_i == -1
-
- # Allocate the checkpointing units with highest cost to RAM, and the
- # remaining units to disk. For read and write costs of one and zero delete
- # costs the distribution of storage between RAM and disk is then equivalent
- # to that in
- # Philipp Stumm and Andrea Walther, 'MultiStage approaches for optimal
- # offline checkpointing', SIAM Journal on Scientific Computing, 31(3),
- # pp. 1946--1967, 2009, doi: 10.1137/080718036
-
- allocation = ["disk" for _ in range(snapshots)]
- for i, _ in sorted(enumerate(weights), key=itemgetter(1),
- reverse=True)[:snapshots_in_ram]:
- allocation[i] = "RAM"
-
- return tuple(weights), tuple(allocation)
-
-
-class MultistageCheckpointSchedule(CheckpointSchedule):
+class MultistageCheckpointSchedule(translation(_MultistageCheckpointSchedule)):
"""A binomial checkpointing schedule using the approach described in
- Andreas Griewank and Andrea Walther, 'Algorithm 799: revolve: an
@@ -251,149 +105,8 @@ class MultistageCheckpointSchedule(CheckpointSchedule):
function in dolfin-adjoint (see e.g. version 2017.1.0).
"""
- def __init__(self, max_n, snapshots_in_ram, snapshots_on_disk, *,
- trajectory="maximum"):
- snapshots_in_ram = min(snapshots_in_ram, max_n - 1)
- snapshots_on_disk = min(snapshots_on_disk, max_n - 1)
- if snapshots_in_ram == 0:
- storage = tuple("disk" for _ in range(snapshots_on_disk))
- elif snapshots_on_disk == 0:
- storage = tuple("RAM" for _ in range(snapshots_in_ram))
- else:
- _, storage = allocate_snapshots(
- max_n, snapshots_in_ram, snapshots_on_disk,
- trajectory=trajectory)
-
- snapshots_in_ram = storage.count("RAM")
- snapshots_on_disk = storage.count("disk")
-
- super().__init__(max_n=max_n)
- self._snapshots_in_ram = snapshots_in_ram
- self._snapshots_on_disk = snapshots_on_disk
- self._storage = storage
- self._exhausted = False
- self._trajectory = trajectory
-
- def iter(self):
- snapshots = []
-
- def write(n):
- if len(snapshots) >= self._snapshots_in_ram + self._snapshots_on_disk: # noqa: E501
- raise RuntimeError("Invalid checkpointing state")
- snapshots.append(n)
- return self._storage[len(snapshots) - 1]
-
- # Forward
-
- if self._max_n is None:
- raise RuntimeError("Invalid checkpointing state")
- while self._n < self._max_n - 1:
- yield Configure(True, False)
-
- n_snapshots = (self._snapshots_in_ram
- + self._snapshots_on_disk
- - len(snapshots))
- n0 = self._n
- n1 = n0 + n_advance(self._max_n - n0, n_snapshots,
- trajectory=self._trajectory)
- assert n1 > n0
- self._n = n1
- yield Forward(n0, n1)
-
- cp_storage = write(n0)
- yield Write(n0, cp_storage)
- yield Clear(True, True)
- if self._n != self._max_n - 1:
- raise RuntimeError("Invalid checkpointing state")
-
- # Forward -> reverse
-
- yield Configure(False, True)
-
- self._n += 1
- yield Forward(self._n - 1, self._n)
-
- yield EndForward()
- self._r += 1
- yield Reverse(self._n, self._n - 1)
- yield Clear(True, True)
-
- # Reverse
-
- while self._r < self._max_n:
- if len(snapshots) == 0:
- raise RuntimeError("Invalid checkpointing state")
- cp_n = snapshots[-1]
- cp_storage = self._storage[len(snapshots) - 1]
- if cp_n == self._max_n - self._r - 1:
- snapshots.pop()
- self._n = cp_n
- yield Read(cp_n, cp_storage, True)
- yield Clear(True, True)
- else:
- self._n = cp_n
- yield Read(cp_n, cp_storage, False)
- yield Clear(True, True)
-
- yield Configure(False, False)
-
- n_snapshots = (self._snapshots_in_ram
- + self._snapshots_on_disk
- - len(snapshots) + 1)
- n0 = self._n
- n1 = n0 + n_advance(self._max_n - self._r - n0, n_snapshots,
- trajectory=self._trajectory)
- assert n1 > n0
- self._n = n1
- yield Forward(n0, n1)
- yield Clear(True, True)
-
- while self._n < self._max_n - self._r - 1:
- yield Configure(True, False)
-
- n_snapshots = (self._snapshots_in_ram
- + self._snapshots_on_disk
- - len(snapshots))
- n0 = self._n
- n1 = n0 + n_advance(self._max_n - self._r - n0, n_snapshots, # noqa: E501
- trajectory=self._trajectory)
- assert n1 > n0
- self._n = n1
- yield Forward(n0, n1)
-
- cp_storage = write(n0)
- yield Write(n0, cp_storage)
- yield Clear(True, True)
- if self._n != self._max_n - self._r - 1:
- raise RuntimeError("Invalid checkpointing state")
-
- yield Configure(False, True)
-
- self._n += 1
- yield Forward(self._n - 1, self._n)
-
- self._r += 1
- yield Reverse(self._n, self._n - 1)
- yield Clear(True, True)
- if self._r != self._max_n:
- raise RuntimeError("Invalid checkpointing state")
- if len(snapshots) != 0:
- raise RuntimeError("Invalid checkpointing state")
-
- self._exhausted = True
- yield EndReverse(True)
-
- @property
- def is_exhausted(self):
- return self._exhausted
-
- @property
- def uses_disk_storage(self):
- return self._snapshots_on_disk > 0
-
-
-class TwoLevelCheckpointSchedule(CheckpointSchedule):
+class TwoLevelCheckpointSchedule(translation(_TwoLevelCheckpointSchedule)):
"""A two-level mixed periodic/binomial checkpointing schedule using the
approach described in
@@ -426,129 +139,8 @@ class TwoLevelCheckpointSchedule(CheckpointSchedule):
def __init__(self, period, binomial_snapshots, *,
binomial_storage="disk",
binomial_trajectory="maximum"):
- if period < 1:
- raise ValueError("period must be positive")
- if binomial_storage not in {"RAM", "disk"}:
- raise ValueError("Invalid storage")
-
- super().__init__()
-
- self._period = period
- self._binomial_snapshots = binomial_snapshots
- self._binomial_storage = binomial_storage
- self._trajectory = binomial_trajectory
-
- def iter(self):
- # Forward
-
- while self._max_n is None:
- yield Configure(True, False)
- if self._max_n is not None:
- # Unexpected finalize
- raise RuntimeError("Invalid checkpointing state")
- n0 = self._n
- n1 = n0 + self._period
- self._n = n1
- yield Forward(n0, n1)
-
- # Finalize permitted here
-
- yield Write(n0, "disk")
- yield Clear(True, True)
-
- yield EndForward()
-
- while True:
- # Reverse
-
- while self._r < self._max_n:
- n = self._max_n - self._r - 1
- n0s = (n // self._period) * self._period
- n1s = min(n0s + self._period, self._max_n)
- if self._r != self._max_n - n1s:
- raise RuntimeError("Invalid checkpointing state")
- del n, n1s
-
- snapshots = [n0s]
- while self._r < self._max_n - n0s:
- if len(snapshots) == 0:
- raise RuntimeError("Invalid checkpointing state")
- cp_n = snapshots[-1]
- if cp_n == self._max_n - self._r - 1:
- snapshots.pop()
- self._n = cp_n
- if cp_n == n0s:
- yield Read(cp_n, "disk", False)
- else:
- yield Read(cp_n, self._binomial_storage, True)
- yield Clear(True, True)
- else:
- self._n = cp_n
- if cp_n == n0s:
- yield Read(cp_n, "disk", False)
- else:
- yield Read(cp_n, self._binomial_storage, False)
- yield Clear(True, True)
-
- yield Configure(False, False)
-
- n_snapshots = (self._binomial_snapshots + 1
- - len(snapshots) + 1)
- n0 = self._n
- n1 = n0 + n_advance(self._max_n - self._r - n0,
- n_snapshots,
- trajectory=self._trajectory)
- assert n1 > n0
- self._n = n1
- yield Forward(n0, n1)
- yield Clear(True, True)
-
- while self._n < self._max_n - self._r - 1:
- yield Configure(True, False)
-
- n_snapshots = (self._binomial_snapshots + 1
- - len(snapshots))
- n0 = self._n
- n1 = n0 + n_advance(self._max_n - self._r - n0,
- n_snapshots,
- trajectory=self._trajectory)
- assert n1 > n0
- self._n = n1
- yield Forward(n0, n1)
-
- if len(snapshots) >= self._binomial_snapshots + 1:
- raise RuntimeError("Invalid checkpointing "
- "state")
- snapshots.append(n0)
- yield Write(n0, self._binomial_storage)
- yield Clear(True, True)
- if self._n != self._max_n - self._r - 1:
- raise RuntimeError("Invalid checkpointing state")
-
- yield Configure(False, True)
-
- self._n += 1
- yield Forward(self._n - 1, self._n)
-
- self._r += 1
- yield Reverse(self._n, self._n - 1)
- yield Clear(True, True)
- if self._r != self._max_n - n0s:
- raise RuntimeError("Invalid checkpointing state")
- if len(snapshots) != 0:
- raise RuntimeError("Invalid checkpointing state")
- if self._r != self._max_n:
- raise RuntimeError("Invalid checkpointing state")
-
- # Reset for new reverse
-
- self._r = 0
- yield EndReverse(False)
-
- @property
- def is_exhausted(self):
- return False
-
- @property
- def uses_disk_storage(self):
- return True
+ super().__init__(
+ period, binomial_snapshots,
+ binomial_storage={"RAM": StorageType.RAM,
+ "disk": StorageType.DISK}[binomial_storage],
+ binomial_trajectory=binomial_trajectory)
diff --git a/tlm_adjoint/checkpoint_schedules/h_revolve.py b/tlm_adjoint/checkpoint_schedules/h_revolve.py
index f22a5b636..8b0e081ec 100644
--- a/tlm_adjoint/checkpoint_schedules/h_revolve.py
+++ b/tlm_adjoint/checkpoint_schedules/h_revolve.py
@@ -1,8 +1,6 @@
-from .schedule import (
- CheckpointSchedule, Clear, Configure, Forward, Reverse, Read, Write,
- EndForward, EndReverse)
+from checkpoint_schedules import HRevolve
-import logging
+from .translation import translation
__all__ = \
[
@@ -10,7 +8,7 @@
]
-class HRevolveCheckpointSchedule(CheckpointSchedule):
+class HRevolveCheckpointSchedule(translation(HRevolve)):
"""An H-Revolve checkpointing schedule.
Converts from schedules as generated by the H-Revolve library, for the
@@ -23,167 +21,21 @@ class HRevolveCheckpointSchedule(CheckpointSchedule):
to store in memory.
:arg snapshots_on_disk: The maximum number of forward restart checkpoints
to store on disk.
- :arg wvect: A two element :class:`tuple` defining the write cost associated
- with saving a forward restart checkpoint to RAM (first element) and
- disk (second element).
- :arg rvect: A two element :class:`tuple` defining the read cost associated
- with loading a forward restart checkpoint from RAM (first element) and
- disk (second element).
+ :arg wd: The write cost associated with saving a forward restart checkpoint
+ to disk.
+ :arg rd: The read cost associated with loading a forward restart checkpoint
+ from disk.
:arg uf: The cost of advancing the forward one step.
:arg bf: The cost of advancing the forward one step, storing non-linear
dependency data, and then advancing the adjoint over that step.
- Remaining keyword arguments are passed to `hrevolve.hrevolve`.
-
The argument names `snaps_in_ram` and `snaps_on_disk` originate from the
corresponding arguments for the `dolfin_adjoint.solving.adj_checkpointing`
function in dolfin-adjoint (see e.g. version 2017.1.0).
"""
def __init__(self, max_n, snapshots_in_ram, snapshots_on_disk, *,
- wvect=(0.0, 0.1), rvect=(0.0, 0.1), uf=1.0, ub=2.0, **kwargs):
- super().__init__(max_n)
- self._snapshots_in_ram = snapshots_in_ram
- self._snapshots_on_disk = snapshots_on_disk
- self._exhausted = False
-
- cvect = (snapshots_in_ram, snapshots_on_disk)
- import hrevolve
- schedule = hrevolve.hrevolve(max_n - 1, cvect, wvect, rvect,
- uf=uf, ub=ub, **kwargs)
- self._schedule = list(schedule)
-
- logger = logging.getLogger("tlm_adjoint.checkpointing")
- logger.debug(f"H-Revolve schedule: {str(self._schedule):s}")
-
- def iter(self):
- def action(i):
- assert i >= 0 and i < len(self._schedule)
- action = self._schedule[i]
- cp_action = action.type
- if cp_action == "Forward":
- n_0 = action.index
- n_1 = n_0 + 1
- storage = None
- elif cp_action == "Forwards":
- cp_action = "Forward"
- n_0, n_1 = action.index
- if n_1 <= n_0:
- raise RuntimeError("Invalid schedule")
- n_1 += 1
- storage = None
- elif cp_action == "Backward":
- n_0 = action.index
- n_1 = None
- storage = None
- elif cp_action in {"Read", "Write", "Discard"}:
- storage, n_0 = action.index
- n_1 = None
- storage = {0: "RAM", 1: "disk"}[storage]
- else:
- raise RuntimeError(f"Unexpected action: {cp_action:s}")
- return cp_action, (n_0, n_1, storage)
-
- if self._max_n is None:
- raise RuntimeError("Invalid checkpointing state")
-
- snapshots = set()
- deferred_cp = None
-
- def write_deferred_cp():
- nonlocal deferred_cp
-
- if deferred_cp is not None:
- snapshots.add(deferred_cp[0])
- yield Write(*deferred_cp)
- deferred_cp = None
-
- for i in range(len(self._schedule)):
- cp_action, (n_0, n_1, storage) = action(i)
-
- if cp_action == "Forward":
- if n_0 != self._n:
- raise RuntimeError("Invalid checkpointing state")
-
- yield Clear(True, True)
- yield Configure(n_0 not in snapshots, False)
- self._n = n_1
- yield Forward(n_0, n_1)
- elif cp_action == "Backward":
- if n_0 != self._n:
- raise RuntimeError("Invalid checkpointing state")
- if n_0 != self._max_n - self._r - 1:
- raise RuntimeError("Invalid checkpointing state")
-
- yield from write_deferred_cp()
-
- yield Clear(True, True)
- yield Configure(False, True)
- self._n = n_0 + 1
- yield Forward(n_0, n_0 + 1)
- if self._n == self._max_n:
- if self._r != 0:
- raise RuntimeError("Invalid checkpointing state")
- yield EndForward()
- self._r += 1
- yield Reverse(n_0 + 1, n_0)
- elif cp_action == "Read":
- if deferred_cp is not None:
- raise RuntimeError("Invalid checkpointing state")
-
- if n_0 == self._max_n - self._r - 1:
- cp_delete = True
- elif i < len(self._schedule) - 2:
- d_cp_action, (d_n_0, _, d_storage) = action(i + 2)
- if d_cp_action == "Discard":
- if d_n_0 != n_0 or d_storage != storage:
- raise RuntimeError("Invalid schedule")
- cp_delete = True
- else:
- cp_delete = False
-
- yield Clear(True, True)
- if cp_delete:
- snapshots.remove(n_0)
- self._n = n_0
- yield Read(n_0, storage, cp_delete)
- elif cp_action == "Write":
- if n_0 != self._n:
- raise RuntimeError("Invalid checkpointing state")
-
- yield from write_deferred_cp()
-
- deferred_cp = (n_0, storage)
-
- if i > 0:
- r_cp_action, (r_n_0, _, _) = action(i - 1)
- if r_cp_action == "Read":
- if r_n_0 != n_0:
- raise RuntimeError("Invalid schedule")
- yield from write_deferred_cp()
- elif cp_action == "Discard":
- if i < 2:
- raise RuntimeError("Invalid schedule")
- r_cp_action, (r_n_0, _, r_storage) = action(i - 2)
- if r_cp_action != "Read" \
- or r_n_0 != n_0 \
- or r_storage != storage:
- raise RuntimeError("Invalid schedule")
- else:
- raise RuntimeError(f"Unexpected action: {cp_action:s}")
-
- if len(snapshots) != 0:
- raise RuntimeError("Invalid checkpointing state")
-
- yield Clear(True, True)
-
- self._exhausted = True
- yield EndReverse(True)
-
- @property
- def is_exhausted(self):
- return self._exhausted
-
- @property
- def uses_disk_storage(self):
- return self._snapshots_on_disk > 0
+ wd=0.1, rd=0.1, uf=1.0, ub=2.0):
+ super().__init__(
+ max_n, snapshots_in_ram, snapshots_on_disk,
+ uf=uf, ub=ub, wd=wd, rd=rd)
diff --git a/tlm_adjoint/checkpoint_schedules/memory.py b/tlm_adjoint/checkpoint_schedules/memory.py
index 2d47ad7c7..affbe87cf 100644
--- a/tlm_adjoint/checkpoint_schedules/memory.py
+++ b/tlm_adjoint/checkpoint_schedules/memory.py
@@ -1,7 +1,7 @@
-from .schedule import (
- CheckpointSchedule, Configure, Forward, Reverse, EndForward, EndReverse)
+from checkpoint_schedules import (
+ SingleMemoryStorageSchedule as _SingleMemoryStorageSchedule)
-import sys
+from .translation import translation
__all__ = \
[
@@ -9,47 +9,9 @@
]
-class MemoryCheckpointSchedule(CheckpointSchedule):
- """A checkpointing schedule where all forward restart and non-linear
- dependency data are stored in memory.
+class MemoryCheckpointSchedule(translation(_SingleMemoryStorageSchedule)):
+ """A checkpointing schedule where all non-linear dependency data are stored
+ in memory.
Online, unlimited adjoint calculations permitted.
"""
-
- def iter(self):
- # Forward
-
- if self._max_n is not None:
- # Unexpected finalize
- raise RuntimeError("Invalid checkpointing state")
- yield Configure(True, True)
-
- while self._max_n is None:
- n0 = self._n
- n1 = n0 + sys.maxsize
- self._n = n1
- yield Forward(n0, n1)
-
- yield EndForward()
-
- while True:
- if self._r == 0:
- # Reverse
-
- self._r = self._max_n
- yield Reverse(self._max_n, 0)
- elif self._r == self._max_n:
- # Reset for new reverse
-
- self._r = 0
- yield EndReverse(False)
- else:
- raise RuntimeError("Invalid checkpointing state")
-
- @property
- def is_exhausted(self):
- return False
-
- @property
- def uses_disk_storage(self):
- return False
diff --git a/tlm_adjoint/checkpoint_schedules/mixed.py b/tlm_adjoint/checkpoint_schedules/mixed.py
index aeffc48a2..2d52b99c5 100644
--- a/tlm_adjoint/checkpoint_schedules/mixed.py
+++ b/tlm_adjoint/checkpoint_schedules/mixed.py
@@ -1,24 +1,11 @@
-from .schedule import (
- CheckpointSchedule, Clear, Configure, Forward, Reverse, Read, Write,
- EndForward, EndReverse)
+from checkpoint_schedules import StorageType
+from checkpoint_schedules import (
+ MixedCheckpointSchedule as _MixedCheckpointSchedule)
+
+from .translation import translation
import enum
import functools
-import numpy as np
-import warnings
-
-try:
- import numba
- from numba import njit
-except ImportError:
- numba = None
-
- def njit(fn):
- @functools.wraps(fn)
- def wrapped_fn(*args, **kwargs):
- return fn(*args, **kwargs)
- return wrapped_fn
-
__all__ = \
[
@@ -102,112 +89,7 @@ def mixed_step_memoization(n, s):
return m
-_NONE = int(StepType.NONE)
-_FORWARD = int(StepType.FORWARD)
-_FORWARD_REVERSE = int(StepType.FORWARD_REVERSE)
-_WRITE_DATA = int(StepType.WRITE_DATA)
-_WRITE_ICS = int(StepType.WRITE_ICS)
-
-
-@njit
-def mixed_steps_tabulation(n, s):
- schedule = np.zeros((n + 1, s + 1, 3), dtype=np.int_)
- schedule[:, :, 0] = _NONE
- schedule[:, :, 1] = 0
- schedule[:, :, 2] = -1
-
- for s_i in range(s + 1):
- schedule[1, s_i, :] = (_FORWARD_REVERSE, 1, 1)
- for s_i in range(1, s + 1):
- for n_i in range(2, n + 1):
- if n_i <= s_i + 1:
- schedule[n_i, s_i, :] = (_WRITE_DATA, 1, n_i)
- elif s_i == 1:
- schedule[n_i, s_i, :] = (_WRITE_ICS, n_i - 1, n_i * (n_i + 1) // 2 - 1) # noqa: E501
- else:
- for i in range(2, n_i):
- assert schedule[i, s_i, 2] > 0
- assert schedule[n_i - i, s_i - 1, 2] > 0
- m1 = (
- i
- + schedule[i, s_i, 2]
- + schedule[n_i - i, s_i - 1, 2])
- if schedule[n_i, s_i, 2] < 0 or m1 <= schedule[n_i, s_i, 2]: # noqa: E501
- schedule[n_i, s_i, :] = (_WRITE_ICS, i, m1)
- if schedule[n_i, s_i, 2] < 0:
- raise RuntimeError("Failed to determine total number of "
- "steps")
- assert schedule[n_i - 1, s_i - 1, 2] > 0
- m1 = 1 + schedule[n_i - 1, s_i - 1, 2]
- if m1 <= schedule[n_i, s_i, 2]:
- schedule[n_i, s_i, :] = (_WRITE_DATA, 1, m1)
- return schedule
-
-
-def cache_step_0(fn):
- _cache = {}
-
- @functools.wraps(fn)
- def wrapped_fn(n, s):
- # Avoid some cache misses
- s = min(s, n - 2)
- if (n, s) not in _cache:
- _cache[(n, s)] = fn(n, s)
- return _cache[(n, s)]
-
- return wrapped_fn
-
-
-@cache_step_0
-def mixed_step_memoization_0(n, s):
- if s < 0:
- raise ValueError("Invalid number of snapshots")
- if n < s + 2:
- raise ValueError("Invalid number of steps")
-
- if s == 0:
- return (StepType.FORWARD_REVERSE, n, n * (n + 1) // 2 - 1)
- else:
- m = None
- for i in range(1, n):
- m1 = (
- i
- + mixed_step_memoization(i, s + 1)[2]
- + mixed_step_memoization(n - i, s)[2])
- if m is None or m1 <= m[2]:
- m = (StepType.FORWARD, i, m1)
- if m is None:
- raise RuntimeError("Failed to determine total number of steps")
- return m
-
-
-@njit
-def mixed_steps_tabulation_0(n, s, schedule):
- schedule_0 = np.zeros((n + 1, s + 1, 3), dtype=np.int_)
- schedule_0[:, :, 0] = _NONE
- schedule_0[:, :, 1] = 0
- schedule_0[:, :, 2] = -1
-
- for n_i in range(2, n + 1):
- schedule_0[n_i, 0, :] = (_FORWARD_REVERSE, n_i, n_i * (n_i + 1) // 2 - 1) # noqa: E501
- for s_i in range(1, s):
- for n_i in range(s_i + 2, n + 1):
- for i in range(1, n_i):
- assert schedule[i, s_i + 1, 2] > 0
- assert schedule[n_i - i, s_i, 2] > 0
- m1 = (
- i
- + schedule[i, s_i + 1, 2]
- + schedule[n_i - i, s_i, 2])
- if schedule_0[n_i, s_i, 2] < 0 or m1 <= schedule_0[n_i, s_i, 2]: # noqa: E501
- schedule_0[n_i, s_i, :] = (_FORWARD, i, m1)
- if schedule_0[n_i, s_i, 2] < 0:
- raise RuntimeError("Failed to determine total number of "
- "steps")
- return schedule_0
-
-
-class MixedCheckpointSchedule(CheckpointSchedule):
+class MixedCheckpointSchedule(translation(_MixedCheckpointSchedule)):
"""A checkpointing schedule which mixes storage of forward restart data and
non-linear dependency data in checkpointing units. Assumes that the data
required to restart the forward has the same size as the data required to
@@ -228,154 +110,7 @@ class MixedCheckpointSchedule(CheckpointSchedule):
"""
def __init__(self, max_n, snapshots, *, storage="disk"):
- if snapshots < min(1, max_n - 1):
- raise ValueError("Invalid number of snapshots")
- if storage not in {"RAM", "disk"}:
- raise ValueError("Invalid storage")
-
- super().__init__(max_n)
- self._exhausted = False
- self._snapshots = min(snapshots, max_n - 1)
- self._storage = storage
-
- def iter(self):
- snapshot_n = set()
- snapshots = []
-
- if self._max_n is None:
- raise RuntimeError("Invalid checkpointing state")
-
- if numba is None:
- warnings.warn("Numba not available -- using memoization",
- RuntimeWarning)
- else:
- schedule = mixed_steps_tabulation(self._max_n, self._snapshots)
- schedule_0 = mixed_steps_tabulation_0(self._max_n, self._snapshots, schedule) # noqa: E501
-
- step_type = StepType.NONE
- while True:
- while self._n < self._max_n - self._r:
- n0 = self._n
- if n0 in snapshot_n:
- # n0 checkpoint exists
- if numba is None:
- step_type, n1, _ = mixed_step_memoization_0(
- self._max_n - self._r - n0,
- self._snapshots - len(snapshots))
- else:
- step_type, n1, _ = schedule_0[
- self._max_n - self._r - n0,
- self._snapshots - len(snapshots)]
- else:
- # n0 checkpoint does not exist
- if numba is None:
- step_type, n1, _ = mixed_step_memoization(
- self._max_n - self._r - n0,
- self._snapshots - len(snapshots))
- else:
- step_type, n1, _ = schedule[
- self._max_n - self._r - n0,
- self._snapshots - len(snapshots)]
- n1 += n0
-
- if step_type == StepType.FORWARD_REVERSE:
- if n1 > n0 + 1:
- yield Configure(False, False)
- self._n = n1 - 1
- yield Forward(n0, n1 - 1)
- yield Clear(True, True)
- elif n1 <= n0:
- raise RuntimeError("Invalid step")
- yield Configure(False, True)
- self._n += 1
- yield Forward(n1 - 1, n1)
- elif step_type == StepType.FORWARD:
- if n1 <= n0:
- raise RuntimeError("Invalid step")
- yield Configure(False, False)
- self._n = n1
- yield Forward(n0, n1)
- yield Clear(True, True)
- elif step_type == StepType.WRITE_DATA:
- if n1 != n0 + 1:
- raise RuntimeError("Invalid step")
- yield Configure(False, True)
- self._n = n1
- yield Forward(n0, n1)
- if n0 in snapshot_n:
- raise RuntimeError("Invalid checkpointing state")
- elif len(snapshots) > self._snapshots - 1:
- raise RuntimeError("Invalid checkpointing state")
- snapshot_n.add(n0)
- snapshots.append((StepType.READ_DATA, n0))
- yield Write(n0, self._storage)
- yield Clear(True, True)
- elif step_type == StepType.WRITE_ICS:
- if n1 <= n0 + 1:
- raise ValueError("Invalid step")
- yield Configure(True, False)
- self._n = n1
- yield Forward(n0, n1)
- if n0 in snapshot_n:
- raise RuntimeError("Invalid checkpointing state")
- elif len(snapshots) > self._snapshots - 1:
- raise RuntimeError("Invalid checkpointing state")
- snapshot_n.add(n0)
- snapshots.append((StepType.READ_ICS, n0))
- yield Write(n0, self._storage)
- yield Clear(True, True)
- else:
- raise RuntimeError("Unexpected step type")
- if self._n != self._max_n - self._r:
- raise RuntimeError("Invalid checkpointing state")
- if step_type not in (StepType.FORWARD_REVERSE, StepType.READ_DATA):
- raise RuntimeError("Invalid checkpointing state")
-
- if self._r == 0:
- yield EndForward()
-
- self._r += 1
- yield Reverse(self._max_n - self._r + 1, self._max_n - self._r)
- yield Clear(True, True)
-
- if self._r == self._max_n:
- break
-
- step_type, cp_n = snapshots[-1]
-
- # Delete if we have (possibly after deleting this checkpoint)
- # enough storage left to store all non-linear dependency data
- cp_delete = (cp_n >= (self._max_n - self._r - 1
- - (self._snapshots - len(snapshots) + 1)))
- if cp_delete:
- snapshot_n.remove(cp_n)
- snapshots.pop()
-
- self._n = cp_n
- if step_type == StepType.READ_DATA:
- # Non-linear dependency data checkpoint
- if not cp_delete:
- # We cannot advance from a loaded non-linear dependency
- # checkpoint, and so we expect to use it immediately
- raise RuntimeError("Invalid checkpointing state")
- # Note that we cannot in general restart the forward here
- self._n += 1
- elif step_type != StepType.READ_ICS:
- raise RuntimeError("Invalid checkpointing state")
- yield Read(cp_n, self._storage, cp_delete)
- if step_type == StepType.READ_ICS:
- yield Clear(True, True)
-
- if len(snapshot_n) > 0 or len(snapshots) > 0:
- raise RuntimeError("Invalid checkpointing state")
-
- self._exhausted = True
- yield EndReverse(True)
-
- @property
- def is_exhausted(self):
- return self._exhausted
-
- @property
- def uses_disk_storage(self):
- return self._max_n > 1 and self._storage == "disk"
+ super().__init__(
+ max_n, snapshots,
+ storage={"RAM": StorageType.RAM,
+ "disk": StorageType.DISK}[storage])
diff --git a/tlm_adjoint/checkpoint_schedules/none.py b/tlm_adjoint/checkpoint_schedules/none.py
index 64287b4a6..bc225bf8b 100644
--- a/tlm_adjoint/checkpoint_schedules/none.py
+++ b/tlm_adjoint/checkpoint_schedules/none.py
@@ -1,6 +1,7 @@
-from .schedule import CheckpointSchedule, Configure, Forward, EndForward
+from checkpoint_schedules import (
+ NoneCheckpointSchedule as _NoneCheckpointSchedule)
-import sys
+from .translation import translation
__all__ = \
[
@@ -8,38 +9,9 @@
]
-class NoneCheckpointSchedule(CheckpointSchedule):
+class NoneCheckpointSchedule(translation(_NoneCheckpointSchedule)):
"""A checkpointing schedule for the case where no adjoint calculation is
performed.
Online, zero adjoint calculations permitted.
"""
-
- def __init__(self):
- super().__init__()
- self._exhausted = False
-
- def iter(self):
- # Forward
-
- if self._max_n is not None:
- # Unexpected finalize
- raise RuntimeError("Invalid checkpointing state")
- yield Configure(False, False)
-
- while self._max_n is None:
- n0 = self._n
- n1 = n0 + sys.maxsize
- self._n = n1
- yield Forward(n0, n1)
-
- self._exhausted = True
- yield EndForward()
-
- @property
- def is_exhausted(self):
- return self._exhausted
-
- @property
- def uses_disk_storage(self):
- return False
diff --git a/tlm_adjoint/checkpoint_schedules/translation.py b/tlm_adjoint/checkpoint_schedules/translation.py
new file mode 100644
index 000000000..0c1999e66
--- /dev/null
+++ b/tlm_adjoint/checkpoint_schedules/translation.py
@@ -0,0 +1,148 @@
+from checkpoint_schedules import StorageType
+from checkpoint_schedules import (
+ Forward as _Forward, Reverse as _Reverse, Copy as _Copy, Move as _Move,
+ EndForward as _EndForward, EndReverse as _EndReverse)
+
+from .schedule import (
+ CheckpointSchedule, Configure, Clear, Forward, Reverse, Read, Write,
+ EndForward, EndReverse)
+
+import functools
+
+__all__ = \
+ [
+ ]
+
+
+def translation(cls):
+ class Translation(CheckpointSchedule):
+ def __init__(self, *args, **kwargs):
+ self._cp_schedule = cls(*args, **kwargs)
+ super().__init__(self._cp_schedule.max_n)
+ self._is_exhausted = self._cp_schedule.is_exhausted
+
+ def iter(self):
+ # Used to ensure that we do not finalize the wrapped scheduler
+ # while yielding actions associated with a single wrapped action.
+ # Prevents multiple finalization of the wrapped schedule.
+ def locked(fn):
+ @functools.wraps(fn)
+ def wrapped_fn(cp_action):
+ max_n = self._max_n
+ try:
+ yield from fn(cp_action)
+ finally:
+ if self._max_n != max_n:
+ self._cp_schedule.finalize(self._max_n)
+ return wrapped_fn
+
+ @functools.singledispatch
+ @locked
+ def action(cp_action):
+ raise TypeError(f"Unexpected action type: {type(cp_action)}")
+ yield None
+
+ @action.register(_Forward)
+ @locked
+ def action_forward(cp_action):
+ yield Clear(True, True)
+ yield Configure(cp_action.write_ics, cp_action.write_adj_deps)
+ self._n = self._cp_schedule.n
+ yield Forward(cp_action.n0, cp_action.n1)
+ if cp_action.storage not in {StorageType.NONE,
+ StorageType.WORK}:
+ yield Write(cp_action.n0,
+ {StorageType.RAM: "RAM",
+ StorageType.DISK: "disk"}[cp_action.storage])
+ yield Clear(True, False)
+
+ @action.register(_Reverse)
+ @locked
+ def action_reverse(cp_action):
+ if self._max_n is None:
+ raise RuntimeError("Invalid checkpointing state")
+ self._r = self._cp_schedule.r
+ yield Reverse(cp_action.n1, cp_action.n0)
+
+ @action.register(_Copy)
+ @locked
+ def action_copy(cp_action):
+ if cp_action.to_storage == StorageType.NONE:
+ pass
+ elif cp_action.to_storage in {StorageType.RAM, StorageType.DISK}: # noqa: E501
+ yield Clear(True, True)
+ self._n = self._cp_schedule.n
+ yield Read(cp_action.n,
+ {StorageType.RAM: "RAM",
+ StorageType.DISK: "disk"}[cp_action.from_storage], # noqa: E501
+ False)
+ yield Write(cp_action.n0,
+ {StorageType.RAM: "RAM",
+ StorageType.DISK: "disk"}[cp_action.to_storage]) # noqa: E501
+ yield Clear(True, True)
+ elif cp_action.to_storage == StorageType.WORK:
+ yield Clear(True, True)
+ self._n = self._cp_schedule.n
+ yield Read(cp_action.n,
+ {StorageType.RAM: "RAM",
+ StorageType.DISK: "disk"}[cp_action.from_storage], # noqa: E501
+ False)
+ else:
+ raise ValueError(f"Unexpected storage type: "
+ f"{cp_action.to_storage}")
+
+ @action.register(_Move)
+ @locked
+ def action_move(cp_action):
+ if cp_action.to_storage == StorageType.NONE:
+ pass
+ elif cp_action.to_storage in {StorageType.RAM, StorageType.DISK}: # noqa: E501
+ yield Clear(True, True)
+ self._n = self._cp_schedule.n
+ yield Read(cp_action.n,
+ {StorageType.RAM: "RAM",
+ StorageType.DISK: "disk"}[cp_action.from_storage], # noqa: E501
+ True)
+ yield Write(cp_action.n0,
+ {StorageType.RAM: "RAM",
+ StorageType.DISK: "disk"}[cp_action.to_storage]) # noqa: E501
+ yield Clear(True, True)
+ elif cp_action.to_storage == StorageType.WORK:
+ yield Clear(True, True)
+ self._n = self._cp_schedule.n
+ yield Read(cp_action.n,
+ {StorageType.RAM: "RAM",
+ StorageType.DISK: "disk"}[cp_action.from_storage], # noqa: E501
+ True)
+ else:
+ raise ValueError(f"Unexpected storage type: "
+ f"{cp_action.to_storage}")
+
+ @action.register(_EndForward)
+ @locked
+ def action_end_forward(cp_action):
+ self._is_exhausted = self._cp_schedule.is_exhausted
+ yield EndForward()
+
+ @action.register(_EndReverse)
+ @locked
+ def action_end_reverse(cp_action):
+ if self._cp_schedule.is_exhausted:
+ yield Clear(True, True)
+ self._r = self._cp_schedule.r
+ self._is_exhausted = self._cp_schedule.is_exhausted
+ yield EndReverse(self._cp_schedule.is_exhausted)
+
+ yield Clear(True, True)
+ while not self._cp_schedule.is_exhausted:
+ yield from action(next(self._cp_schedule))
+
+ @property
+ def is_exhausted(self):
+ return self._is_exhausted
+
+ @property
+ def uses_disk_storage(self):
+ return self._cp_schedule.uses_storage_type(StorageType.DISK)
+
+ return Translation