Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable extending NonEquilibriumCyclingProtocols #44

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 128 additions & 13 deletions feflow/protocols/nonequilibrium_cycling.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@
from openff.units import unit
from openff.units.openmm import to_openmm, from_openmm

from ..utils.data import serialize, deserialize
from ..utils.data import (
serialize,
deserialize,
serialize_and_compress,
decompress_and_deserialize,
)

# Specific instance of logger for this module
# logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -138,6 +143,34 @@ def _execute(self, ctx, *, state_a, state_b, mapping, settings, **inputs):
from openfe.protocols.openmm_rfe import _rfe_utils
from feflow.utils.hybrid_topology import HybridTopologyFactory

if extends_data := self.inputs.get("extends_data"):

def _write_xml(data, filename):
openmm_object = decompress_and_deserialize(data)
serialize(openmm_object, filename)
return filename

for cycle in range(settings.num_cycles):
cycle = str(cycle)
system_outfile = ctx.shared / f"system_{cycle}.xml.bz2"
state_outfile = ctx.shared / f"state_{cycle}.xml.bz2"
integrator_outfile = ctx.shared / f"integrator_{cycle}.xml.bz2"

extends_data["systems"][cycle] = _write_xml(
extends_data["systems"][cycle],
system_outfile,
)
extends_data["states"][cycle] = _write_xml(
extends_data["states"][cycle],
state_outfile,
)
extends_data["integrators"][cycle] = _write_xml(
extends_data["integrators"][cycle],
integrator_outfile,
)

return extends_data

# Check compatibility between states (same receptor and solvent)
self._check_states_compatibility(state_a, state_b)

Expand Down Expand Up @@ -357,10 +390,18 @@ def _execute(self, ctx, *, state_a, state_b, mapping, settings, **inputs):
# Explicit cleanup for GPU resources
del context, integrator

systems = {}
states = {}
integrators = {}
for cycle_name in map(str, range(settings.num_cycles)):
systems[cycle_name] = system_outfile
states[cycle_name] = state_outfile
integrators[cycle_name] = integrator_outfile

return {
"system": system_outfile,
"state": state_outfile,
"integrator": integrator_outfile,
"systems": systems,
"states": states,
"integrators": integrators,
"phase": phase,
"initial_atom_indices": hybrid_factory.initial_atom_indices,
"final_atom_indices": hybrid_factory.final_atom_indices,
Expand Down Expand Up @@ -449,9 +490,9 @@ def _execute(self, ctx, *, setup, settings, **inputs):
file_logger.addHandler(file_handler)

# Get state, system, and integrator from setup unit
system = deserialize(setup.outputs["system"])
state = deserialize(setup.outputs["state"])
integrator = deserialize(setup.outputs["integrator"])
system = deserialize(setup.outputs["systems"][self.name])
state = deserialize(setup.outputs["states"][self.name])
integrator = deserialize(setup.outputs["integrators"][self.name])
PeriodicNonequilibriumIntegrator.restore_interface(integrator)

# Get atom indices for either end of the hybrid topology
Expand Down Expand Up @@ -712,7 +753,20 @@ def _execute(self, ctx, *, setup, settings, **inputs):
"reverse_neq_final": reverse_neq_new_path,
}
finally:
compressed_state = serialize_and_compress(
context.getState(getPositions=True),
)

compressed_system = serialize_and_compress(
context.getSystem(),
)

compressed_integrator = serialize_and_compress(
context.getIntegrator(),
)

# Explicit cleanup for GPU resources

del context, integrator

return {
Expand All @@ -721,6 +775,9 @@ def _execute(self, ctx, *, setup, settings, **inputs):
"trajectory_paths": trajectory_paths,
"log": output_log_path,
"timing_info": timing_info,
"system": compressed_system,
"state": compressed_state,
"integrator": compressed_integrator,
}


Expand Down Expand Up @@ -919,8 +976,67 @@ def _create(
# Handle parameters
if mapping is None:
raise ValueError("`mapping` is required for this Protocol")
if extends:
raise NotImplementedError("Can't extend simulations yet")

extends_data = {}
if isinstance(extends, ProtocolDAGResult):

if not extends.ok():
raise ValueError("Cannot extend protocols that failed")

setup = extends.protocol_units[0]
simulations = extends.protocol_units[1:-1]

r_setup = extends.protocol_unit_results[0]
r_simulations = extends.protocol_unit_results[1:-1]

# confirm consistency
original_state_a = setup.inputs["state_a"].key
original_state_b = setup.inputs["state_b"].key
original_mapping = setup.inputs["mapping"]

if original_state_a != stateA.key:
raise ValueError(
"'stateA' key is not the same as the key provided by the 'extends' ProtocolDAGResult."
)

if original_state_b != stateB.key:
raise ValueError(
"'stateB' key is not the same as the key provided by the 'extends' ProtocolDAGResult."
)

if mapping is not None:
if original_mapping != mapping:
raise ValueError(
"'mapping' is not consistent with the mapping provided by the 'extends' ProtocolDAGResult."
)

# TODO: are there instances where this is too strict?
if setup.inputs["settings"] != self.settings:
raise ValueError(
"protocol settings are not consistent with those present in the SetupUnit of the 'extends' ProtocolDAGResult."
)
Comment on lines +1013 to +1017
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a good check to have, and is not too strict at all. The philosophy behind a Protocol with a given Settings object is that those Settings are immutable once given to the Protocol object, and so calling Protocol.create(..., extends=<protocol_dag_result>) should feature identical settings between the calling Protocol and the ProtocolDAGResult we wish to extend from.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just chiming in here - there are a few cases where this is going to be "not ideal" (mostly QoL but possibly could be frustrating).

One example: extending a result but switching compute platform — the results should be the same, yet the settings will differ. The same applies to the trajectory write frequency (imagine someone deciding the trajectory is no longer needed after extending), and so on.

It's of course not a blocker right now, but a "smarter" equality might be necessary (e.g. OpenFreeEnergy/gufe#329).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see your point, and think it makes sense to have a thermodynamically_equals method or function for Protocol settings objects for this purpose.

Systems like alchemiscale may limit the way users can do extends in the sense that Tasks may only extend other Tasks from the same Transformation, but the Protocol system on its own isn't as strict, since the concept of a Transformation sits outside of it.


else:
mapping = original_mapping

systems = {}
states = {}
integrators = {}

for r_simulation, simulation in zip(r_simulations, simulations):
sim_name = simulation.name
systems[sim_name] = r_simulation.outputs["system"]
states[sim_name] = r_simulation.outputs["state"]
integrators[sim_name] = r_simulation.outputs["integrator"]

extends_data = dict(
systems=systems,
states=states,
integrators=integrators,
phase=r_setup.outputs["phase"],
initial_atom_indices=r_setup.outputs["initial_atom_indices"],
final_atom_indices=r_setup.outputs["final_atom_indices"],
)

# inputs to `ProtocolUnit.__init__` should either be `Gufe` objects
# or JSON-serializable objects
Expand All @@ -932,13 +1048,12 @@ def _create(
mapping=mapping,
settings=self.settings,
name="setup",
extends_data=extends_data,
)

simulations = [
self._simulation_unit(
setup=setup, settings=self.settings, name=f"{replicate}"
)
for replicate in range(num_cycles)
self._simulation_unit(setup=setup, settings=self.settings, name=f"{cycle}")
for cycle in range(num_cycles)
]

end = ResultUnit(name="result", simulations=simulations)
Expand Down
1 change: 1 addition & 0 deletions feflow/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def short_settings():
settings.thermo_settings.temperature = 300 * unit.kelvin
settings.integrator_settings.equilibrium_steps = 250
settings.integrator_settings.nonequilibrium_steps = 250
settings.num_cycles = 1
settings.work_save_frequency = 50
settings.traj_save_frequency = 250

Expand Down
86 changes: 86 additions & 0 deletions feflow/tests/test_nonequilibrium_cycling.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,92 @@ def test_terminal_units(self, protocol_dag_result):
assert isinstance(finals[0], ProtocolUnitResult)
assert finals[0].name == "result"

@pytest.mark.parametrize(
"protocol",
[
"protocol_short",
"protocol_short_multiple_cycles",
],
)
def test_pdr_extend(
self,
protocol,
benzene_vacuum_system,
toluene_vacuum_system,
mapping_benzene_toluene,
tmpdir,
request,
):
protocol = request.getfixturevalue(protocol)
dag = protocol.create(
stateA=benzene_vacuum_system,
stateB=toluene_vacuum_system,
name="Short vacuum transformation",
mapping=mapping_benzene_toluene,
)

with tmpdir.as_cwd():

base_path = Path("original")

shared = base_path / "shared"
shared.mkdir(parents=True)

scratch = base_path / "scratch"
scratch.mkdir(parents=True)

pdr: ProtocolDAGResult = execute_DAG(
dag, shared_basedir=shared, scratch_basedir=scratch
)

setup = pdr.protocol_units[0]
r_setup = pdr.protocol_unit_results[0]

assert setup.inputs["extends_data"] == {}

end_states = {}
for simulation, r_simulation in zip(
pdr.protocol_units[1:-1], pdr.protocol_unit_results[1:-1]
):
assert isinstance(r_simulation.outputs["system"], str)
assert isinstance(r_simulation.outputs["state"], str)
assert isinstance(r_simulation.outputs["integrator"], str)

end_states[simulation.name] = r_simulation.outputs["state"]

dag = protocol.create(
stateA=benzene_vacuum_system,
stateB=toluene_vacuum_system,
name="Short vacuum transformation, but extended",
mapping=mapping_benzene_toluene,
extends=ProtocolDAGResult.from_dict(pdr.to_dict()),
)

with tmpdir.as_cwd():

base_path = Path("extended")

shared = base_path / "shared"
shared.mkdir(parents=True)

scratch = base_path / "scratch"
scratch.mkdir(parents=True)
pdr: ProtocolDAGResult = execute_DAG(
dag, shared_basedir=shared, scratch_basedir=scratch
)

r_setup = pdr.protocol_unit_results[0]

assert r_setup.inputs["extends_data"] != {}

for cycle in range(protocol.settings.num_cycles):
cycle = str(cycle)
assert isinstance(r_setup.inputs["extends_data"]["systems"][cycle], str)
assert isinstance(r_setup.inputs["extends_data"]["states"][cycle], str)
assert isinstance(r_setup.inputs["extends_data"]["integrators"][cycle], str)

assert r_setup.inputs["extends_data"]["states"][cycle] == end_states[cycle]

# TODO: We probably need to find failure test cases as control
# def test_dag_execute_failure(self, protocol_dag_broken):
# protocol, dag, dagfailure = protocol_dag_broken
Expand Down
44 changes: 43 additions & 1 deletion feflow/utils/data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,48 @@
import os
import pathlib
import bz2
import base64

from openmm import XmlSerializer


def serialize_and_compress(item) -> str:
    """Turn an OpenMM object into a Base64-encoded, bzip2-compressed string.

    Parameters
    ----------
    item : System, State, or Integrator
        The OpenMM object to serialize.

    Returns
    -------
    str
        Base64 text holding the bzip2-compressed XML serialization of
        ``item``; suitable for embedding in JSON-serializable outputs.
    """
    # XML text -> UTF-8 bytes -> bzip2 -> Base64 ASCII string
    xml_bytes = XmlSerializer.serialize(item).encode()
    return base64.b64encode(bz2.compress(xml_bytes)).decode("ascii")


def decompress_and_deserialize(data: str):
    """Recover an OpenMM object produced by :func:`serialize_and_compress`.

    Parameters
    ----------
    data : str
        Base64-encoded, bzip2-compressed XML serialization of an OpenMM
        object (System, State, or Integrator).

    Returns
    -------
    System, State, or Integrator
        The reconstructed OpenMM object.
    """
    # Base64 ASCII string -> bzip2 bytes -> UTF-8 XML text -> OpenMM object
    xml_text = bz2.decompress(base64.b64decode(data)).decode("utf-8")
    return XmlSerializer.deserialize(xml_text)


def serialize(item, filename: pathlib.Path):
Expand All @@ -13,7 +56,6 @@ def serialize(item, filename: pathlib.Path):
filename : str
The filename to serialize to
"""
from openmm import XmlSerializer

# Create parent directory if it doesn't exist
filename_basedir = filename.parent
Expand Down
Loading