Skip to content

Commit

Permalink
Add ConstrainedTractableFlowSystem
Browse files Browse the repository at this point in the history
Defines expected interface for ConstrainedLeapfrogIntegrator
  • Loading branch information
matt-graham committed Aug 9, 2023
1 parent 899e061 commit fc6a954
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 117 deletions.
12 changes: 6 additions & 6 deletions src/mici/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
if TYPE_CHECKING:
from typing import Any, Callable, Optional, Sequence
from mici.states import ChainState
from mici.systems import System, TractableFlowSystem
from mici.systems import ConstrainedTractableFlowSystem, System, TractableFlowSystem
from mici.types import NormFunction


Expand Down Expand Up @@ -291,7 +291,7 @@ def __init__(self, system: TractableFlowSystem, step_size: Optional[float] = Non
step size adapter will be used to set the step size before calling the
:py:meth:`step` method.
"""
a_0 = (3 - 3 ** 0.5) / 6
a_0 = (3 - 3**0.5) / 6
super().__init__(system, (a_0,), step_size, True)


Expand Down Expand Up @@ -758,7 +758,7 @@ class ConstrainedLeapfrogIntegrator(TractableFlowIntegrator):
.. math::
c((\Phi_2(t) \circ \Pi(\lambda)(q, p))_1)
c((\Phi_2(t) \circ \Pi(\lambda)(q, p))_1)
= c((\Phi_2(t)(q, p + \partial c(q)^T \lambda))_1) = 0,
i.e. solving for the values of the Lagrange multipliers such that the position
Expand All @@ -768,7 +768,7 @@ class ConstrainedLeapfrogIntegrator(TractableFlowIntegrator):
\circ \Pi(\lambda)` is then projected in to the cotangent space to compute the final
state pair, with this projection step as noted above typically having an analytic
solution.
For more details see Reich (1996) and section 7.5.1 in Leimkuhler and Reich (2004).
The overall second-order integrator is then defined as the symmetric composition
Expand Down Expand Up @@ -823,7 +823,7 @@ class ConstrainedLeapfrogIntegrator(TractableFlowIntegrator):

def __init__(
self,
system: System,
system: ConstrainedTractableFlowSystem,
step_size: Optional[float] = None,
n_inner_step: int = 1,
reverse_check_tol: float = 2e-8,
Expand All @@ -833,7 +833,7 @@ def __init__(
):
"""
Args:
system: Hamiltonian system to integrate the dynamics of.
system: Hamiltonian system to integrate the constrained dynamics of.
step_size: Integrator time step. If set to :code:`None` it is assumed that a
step size adapter will be used to set the step size before calling the
:py:meth:`step` method.
Expand Down
35 changes: 23 additions & 12 deletions src/mici/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@

if TYPE_CHECKING:
from mici.states import ChainState
from mici.systems import System
from mici.systems import (
ConstrainedEuclideanMetricSystem,
ConstrainedTractableFlowSystem,
)
from mici.types import ScalarFunction, ArrayFunction, ArrayLike


def euclidean_norm(vct):
"""Calculate the Euclidean (L-2) norm of a vector."""
return (vct ** 2).sum() ** 0.5
return (vct**2).sum() ** 0.5


def maximum_norm(vct):
Expand Down Expand Up @@ -117,7 +120,7 @@ def solve_fixed_point_steffensen(
:code:`norm(func(x) - x) < convergence_tol`.
Raises:
mici.errors.ConvergenceError: If solver does not converge within
mici.errors.ConvergenceError: If solver does not converge within
:code:`max_iters` iterations, diverges or encounters a :py:exc:`ValueError`
during the iteration.
"""
Expand Down Expand Up @@ -162,15 +165,15 @@ class ProjectionSolver(Protocol):
manifold defined by the zero level set of a constraint function :math:`c`, with
:math:`\Phi_{2,1}` the flow map for the :math:`h_2` Hamiltonian component for the
system restricted to the position component output. The map :math:`\Phi_{2,1}` is
assumed to be linear in its second (momentum) argument.
assumed to be linear in its second (momentum) argument.
"""

def __call__(
self,
state: ChainState,
state_prev: ChainState,
time_step: float,
system: System,
system: ConstrainedTractableFlowSystem,
**kwargs,
) -> ChainState:
"""Solve for projection on to manifold step.
Expand All @@ -179,7 +182,7 @@ def __call__(
state: Current chain state after unconstrained step.
state_prev: Previous chain state on manifold.
time_step: Integrator time step for unconstrained step.
system: Hamiltonian system dynamics are being simulated for.
system: Hamiltonian system constrained dynamics are being simulated for.
Returns:
Chain state after projection on to manifold.
Expand All @@ -191,7 +194,7 @@ def solve_projection_onto_manifold_quasi_newton(
state: ChainState,
state_prev: ChainState,
time_step: float,
system: System,
system: ConstrainedEuclideanMetricSystem,
constraint_tol: float = 1e-9,
position_tol: float = 1e-8,
divergence_tol: float = 1e10,
Expand Down Expand Up @@ -302,7 +305,9 @@ def solve_projection_onto_manifold_quasi_newton(
mu = np.zeros_like(state.pos)
jacob_constr_prev = system.jacob_constr(state_prev)
# Use absolute value of dt and adjust for sign of dt in mom update below
dh2_flow_pos_dmom, dh2_flow_mom_dmom = system.dh2_flow_dmom(abs(time_step))
dh2_flow_pos_dmom, dh2_flow_mom_dmom = system.dh2_flow_dmom(
state_prev, abs(time_step)
)
inv_jacob_constr_inner_product = system.jacob_constr_inner_product(
jacob_constr_prev, dh2_flow_pos_dmom
).inv
Expand Down Expand Up @@ -338,7 +343,7 @@ def solve_projection_onto_manifold_newton(
state: ChainState,
state_prev: ChainState,
time_step: float,
system: System,
system: ConstrainedEuclideanMetricSystem,
constraint_tol: float = 1e-9,
position_tol: float = 1e-8,
divergence_tol: float = 1e10,
Expand Down Expand Up @@ -420,7 +425,9 @@ def solve_projection_onto_manifold_newton(
mu = np.zeros_like(state.pos)
jacob_constr_prev = system.jacob_constr(state_prev)
# Use absolute value of dt and adjust for sign of dt in mom update below
dh2_flow_pos_dmom, dh2_flow_mom_dmom = system.dh2_flow_dmom(abs(time_step))
dh2_flow_pos_dmom, dh2_flow_mom_dmom = system.dh2_flow_dmom(
state_prev, abs(time_step)
)
for i in range(max_iters):
try:
jacob_constr = system.jacob_constr(state)
Expand Down Expand Up @@ -459,7 +466,7 @@ def solve_projection_onto_manifold_newton_with_line_search(
state: ChainState,
state_prev: ChainState,
time_step: float,
system: System,
system: ConstrainedEuclideanMetricSystem,
constraint_tol: float = 1e-9,
position_tol: float = 1e-8,
divergence_tol: float = 1e10,
Expand Down Expand Up @@ -549,7 +556,11 @@ def solve_projection_onto_manifold_newton_with_line_search(
mu = np.zeros_like(state.pos)
jacob_constr_prev = system.jacob_constr(state_prev)
# Use absolute value of dt and adjust for sign of dt in mom update below
dh2_flow_pos_dmom, dh2_flow_mom_dmom = system.dh2_flow_dmom(abs(time_step))
dh2_flow_pos_dmom, dh2_flow_mom_dmom = system.dh2_flow_dmom(
state_prev, abs(time_step)
)
# Initialize with dummy values to avoid undefined name linter errors
delta_pos, step_size = None, None
for i in range(max_iters):
try:
jacob_constr = system.jacob_constr(state)
Expand Down
Loading

0 comments on commit fc6a954

Please sign in to comment.