Skip to content

Commit

Permalink
Use optimal_steps_binomial from checkpoint_schedules
Browse files Browse the repository at this point in the history
  • Loading branch information
jrmaddison committed Jun 26, 2024
1 parent 5311c3f commit f945395
Showing 1 changed file with 2 additions and 50 deletions.
52 changes: 2 additions & 50 deletions tlm_adjoint/checkpoint_schedules/binomial.py
Original file line number Diff line number Diff line change
@@ -1,66 +1,18 @@
from checkpoint_schedules import StorageType
from checkpoint_schedules import (
MultistageCheckpointSchedule as _MultistageCheckpointSchedule,
TwoLevelCheckpointSchedule as _TwoLevelCheckpointSchedule)
TwoLevelCheckpointSchedule as _TwoLevelCheckpointSchedule,
optimal_steps_binomial as optimal_steps)

from .translation import translation

import functools

__all__ = \
[
"MultistageCheckpointSchedule",
"TwoLevelCheckpointSchedule"
]


def cache_step(fn):
_cache = {}

@functools.wraps(fn)
def wrapped_fn(n, s):
# Avoid some cache misses
s = min(s, n - 1)
if (n, s) not in _cache:
_cache[(n, s)] = fn(n, s)
return _cache[(n, s)]

return wrapped_fn


@cache_step
def optimal_extra_steps(n, s):
if n <= 0:
raise ValueError("Invalid number of steps")
if s < min(1, n - 1) or s > n - 1:
raise ValueError("Invalid number of snapshots")

if n == 1:
return 0
# Equation (2) of
# A. Griewank and A. Walther, "Algorithm 799: Revolve: An implementation
# of checkpointing for the reverse or adjoint mode of computational
# differentiation", ACM Transactions on Mathematical Software, 26(1), pp.
# 19--45, 2000
elif s == 1:
return n * (n - 1) // 2
else:
m = None
for i in range(1, n):
m1 = (i
+ optimal_extra_steps(i, s)
+ optimal_extra_steps(n - i, s - 1))
if m is None or m1 < m:
m = m1
if m is None:
raise RuntimeError("Failed to determine number of extra steps")
return m


def optimal_steps(n, s):
return n + optimal_extra_steps(n, s)


class MultistageCheckpointSchedule(translation(_MultistageCheckpointSchedule)):
"""A binomial checkpointing schedule using the approach described in
Expand Down

0 comments on commit f945395

Please sign in to comment.