Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DNM] Feature noneq cycling extends #46

Closed
wants to merge 9 commits into from
129 changes: 120 additions & 9 deletions feflow/protocols/nonequilibrium_cycling.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@
from openff.units import unit
from openff.units.openmm import to_openmm, from_openmm

from ..utils.data import serialize, deserialize
from ..utils.data import (
serialize,
deserialize,
serialize_and_compress,
decompress_and_deserialize,
)

# Specific instance of logger for this module
# logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -133,6 +138,34 @@ def _execute(self, ctx, *, state_a, state_b, mapping, settings, **inputs):
from openfe.protocols.openmm_rfe import _rfe_utils
from feflow.utils.hybrid_topology import HybridTopologyFactory

if extends_data := self.inputs.get("extends_data"):

def _write_xml(data, filename):
openmm_object = decompress_and_deserialize(data)
serialize(openmm_object, filename)
return filename

for replicate in range(settings.num_replicates):
replicate = str(replicate)
system_outfile = ctx.shared / f"system_{replicate}.xml.bz2"
state_outfile = ctx.shared / f"state_{replicate}.xml.bz2"
integrator_outfile = ctx.shared / f"integrator_{replicate}.xml.bz2"

extends_data["systems"][replicate] = _write_xml(
extends_data["systems"][replicate],
system_outfile,
)
extends_data["states"][replicate] = _write_xml(
extends_data["states"][replicate],
state_outfile,
)
extends_data["integrators"][replicate] = _write_xml(
extends_data["integrators"][replicate],
integrator_outfile,
)

return extends_data

# Check compatibility between states (same receptor and solvent)
self._check_states_compatibility(state_a, state_b)

Expand Down Expand Up @@ -342,10 +375,18 @@ def _execute(self, ctx, *, state_a, state_b, mapping, settings, **inputs):
# Explicit cleanup for GPU resources
del context, integrator

systems = dict()
states = dict()
integrators = dict()
for replicate_name in map(str, range(settings.num_replicates)):
systems[replicate_name] = system_outfile
states[replicate_name] = state_outfile
integrators[replicate_name] = integrator_outfile

return {
"system": system_outfile,
"state": state_outfile,
"integrator": integrator_outfile,
"systems": systems,
"states": states,
"integrators": integrators,
"phase": phase,
"initial_atom_indices": hybrid_factory.initial_atom_indices,
"final_atom_indices": hybrid_factory.final_atom_indices,
Expand Down Expand Up @@ -434,9 +475,9 @@ def _execute(self, ctx, *, setup, settings, **inputs):
file_logger.addHandler(file_handler)

# Get state, system, and integrator from setup unit
system = deserialize(setup.outputs["system"])
state = deserialize(setup.outputs["state"])
integrator = deserialize(setup.outputs["integrator"])
system = deserialize(setup.outputs["systems"][self.name])
state = deserialize(setup.outputs["states"][self.name])
integrator = deserialize(setup.outputs["integrators"][self.name])
PeriodicNonequilibriumIntegrator.restore_interface(integrator)

# Get atom indices for either end of the hybrid topology
Expand Down Expand Up @@ -687,7 +728,20 @@ def _execute(self, ctx, *, setup, settings, **inputs):
"reverse_neq_final": reverse_neq_new_path,
}
finally:
compressed_state = serialize_and_compress(
context.getState(getPositions=True),
)

compressed_system = serialize_and_compress(
context.getSystem(),
)

compressed_integrator = serialize_and_compress(
context.getIntegrator(),
)

# Explicit cleanup for GPU resources

del context, integrator

return {
Expand All @@ -696,6 +750,9 @@ def _execute(self, ctx, *, setup, settings, **inputs):
"trajectory_paths": trajectory_paths,
"log": output_log_path,
"timing_info": timing_info,
"system": compressed_system,
"state": compressed_state,
"integrator": compressed_integrator,
}


Expand Down Expand Up @@ -890,10 +947,63 @@ def _create(
# Handle parameters
if mapping is None:
raise ValueError("`mapping` is required for this Protocol")

if "ligand" not in mapping:
raise ValueError("'ligand' must be specified in `mapping` dict")
if extends:
raise NotImplementedError("Can't extend simulations yet")

extends_data = {}
if isinstance(extends, ProtocolDAGResult):

if not extends.ok():
raise ValueError("Cannot extend protocols that failed")

setup = extends.protocol_units[0]
simulations = extends.protocol_units[1:-1]

r_setup = extends.protocol_unit_results[0]
r_simulations = extends.protocol_unit_results[1:-1]

# confirm consistency
original_state_a = setup.inputs["state_a"].key
original_state_b = setup.inputs["state_b"].key
original_mapping = setup.inputs["mapping"]

if original_state_a != stateA.key:
raise ValueError(
"'stateA' key is not the same as the key provided by the 'extends' ProtocolDAGResult."
)

if original_state_b != stateB.key:
raise ValueError(
"'stateB' key is not the same as the key provided by the 'extends' ProtocolDAGResult."
)

if mapping is not None:
if original_mapping != mapping:
raise ValueError(
"'mapping' is not consistent with the mapping provided by the 'extnds' ProtocolDAGResult."
)
else:
mapping = original_mapping

systems = {}
states = {}
integrators = {}

for r_simulation, simulation in zip(r_simulations, simulations):
sim_name = simulation.name
systems[sim_name] = r_simulation.outputs["system"]
states[sim_name] = r_simulation.outputs["state"]
integrators[sim_name] = r_simulation.outputs["integrator"]

extends_data = dict(
systems=systems,
states=states,
integrators=integrators,
phase=r_setup.outputs["phase"],
initial_atom_indices=r_setup.outputs["initial_atom_indices"],
final_atom_indices=r_setup.outputs["final_atom_indices"],
)

# inputs to `ProtocolUnit.__init__` should either be `Gufe` objects
# or JSON-serializable objects
Expand All @@ -905,6 +1015,7 @@ def _create(
mapping=mapping,
settings=self.settings,
name="setup",
extends_data=extends_data,
)

simulations = [
Expand Down
4 changes: 2 additions & 2 deletions feflow/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def short_settings():
settings = NonEquilibriumCyclingProtocol.default_settings()

settings.thermo_settings.temperature = 300 * unit.kelvin
settings.eq_steps = 25000
settings.neq_steps = 25000
settings.eq_steps = 1000
settings.neq_steps = 1000
settings.work_save_frequency = 50
settings.traj_save_frequency = 250
settings.platform = "CPU"
Expand Down
92 changes: 92 additions & 0 deletions feflow/tests/test_nonequilibrium_cycling.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,98 @@ def test_terminal_units(self, protocol_dag_result):
assert isinstance(finals[0], ProtocolUnitResult)
assert finals[0].name == "result"

@pytest.mark.parametrize(
"protocol",
[
"protocol_short",
"protocol_short_multiple_cycles",
],
)
def test_pdr_extend(
self,
protocol,
benzene_vacuum_system,
toluene_vacuum_system,
mapping_benzene_toluene,
tmpdir,
request,
):

protocol = request.getfixturevalue(protocol)
dag = protocol.create(
stateA=benzene_vacuum_system,
stateB=toluene_vacuum_system,
name="Short vacuum transformation",
mapping={"ligand": mapping_benzene_toluene},
)

with tmpdir.as_cwd():

base_path = Path("original")

shared = base_path / "shared"
shared.mkdir(parents=True)

scratch = base_path / "scratch"
scratch.mkdir(parents=True)

pdr: ProtocolDAGResult = execute_DAG(
dag, shared_basedir=shared, scratch_basedir=scratch
)

setup = pdr.protocol_units[0]
r_setup = pdr.protocol_unit_results[0]

assert setup.inputs["extends_data"] == {}

end_states = {}
for simulation, r_simulation in zip(
pdr.protocol_units[1:-1], pdr.protocol_unit_results[1:-1]
):
assert isinstance(r_simulation.outputs["system"], str)
assert isinstance(r_simulation.outputs["state"], str)
assert isinstance(r_simulation.outputs["integrator"], str)

end_states[simulation.name] = r_simulation.outputs["state"]

dag = protocol.create(
stateA=benzene_vacuum_system,
stateB=toluene_vacuum_system,
name="Short vacuum transformation, but extended",
mapping={"ligand": mapping_benzene_toluene},
extends=ProtocolDAGResult.from_dict(pdr.to_dict()),
)

with tmpdir.as_cwd():

base_path = Path("extended")

shared = base_path / "shared"
shared.mkdir(parents=True)

scratch = base_path / "scratch"
scratch.mkdir(parents=True)
pdr: ProtocolDAGResult = execute_DAG(
dag, shared_basedir=shared, scratch_basedir=scratch
)

r_setup = pdr.protocol_unit_results[0]

assert r_setup.inputs["extends_data"] != {}

for replicate in range(protocol.settings.num_replicates):
replicate = str(replicate)
assert isinstance(r_setup.inputs["extends_data"]["systems"][replicate], str)
assert isinstance(r_setup.inputs["extends_data"]["states"][replicate], str)
assert isinstance(
r_setup.inputs["extends_data"]["integrators"][replicate], str
)

assert (
r_setup.inputs["extends_data"]["states"][replicate]
== end_states[replicate]
)

def test_dag_execute_failure(self, protocol_dag_broken):
protocol, dag, dagfailure = protocol_dag_broken

Expand Down
44 changes: 43 additions & 1 deletion feflow/utils/data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,48 @@
import os
import pathlib
import bz2
import base64

from openmm import XmlSerializer


def serialize_and_compress(item) -> str:
    """Compress an OpenMM object's XML serialization into a Base64 string.

    Parameters
    ----------
    item : System, State, or Integrator
        The OpenMM object to serialize.

    Returns
    -------
    str
        Base64 (ASCII) encoding of the bzip2-compressed XML serialization
        of ``item``.
    """
    xml_bytes = XmlSerializer.serialize(item).encode()
    return base64.b64encode(bz2.compress(xml_bytes)).decode("ascii")


def decompress_and_deserialize(data: str):
    """Rebuild an OpenMM object from a compressed, Base64-encoded payload.

    Inverse of :func:`serialize_and_compress`.

    Parameters
    ----------
    data : str
        Base64 string containing a bzip2-compressed XML serialization of an
        OpenMM object.

    Returns
    -------
    The deserialized OpenMM object.
    """
    xml_text = bz2.decompress(base64.b64decode(data)).decode("utf-8")
    return XmlSerializer.deserialize(xml_text)


def serialize(item, filename: pathlib.Path):
Expand All @@ -13,7 +56,6 @@ def serialize(item, filename: pathlib.Path):
filename : str
The filename to serialize to
"""
from openmm import XmlSerializer

# Create parent directory if it doesn't exist
filename_basedir = filename.parent
Expand Down
Loading