diff --git a/feflow/protocols/nonequilibrium_cycling.py b/feflow/protocols/nonequilibrium_cycling.py index d33c2e9..9067b5a 100644 --- a/feflow/protocols/nonequilibrium_cycling.py +++ b/feflow/protocols/nonequilibrium_cycling.py @@ -25,7 +25,12 @@ from openff.units import unit from openff.units.openmm import to_openmm, from_openmm -from ..utils.data import serialize, deserialize +from ..utils.data import ( + serialize, + deserialize, + serialize_and_compress, + decompress_and_deserialize, +) # Specific instance of logger for this module # logger = logging.getLogger(__name__) @@ -138,6 +143,34 @@ def _execute(self, ctx, *, state_a, state_b, mapping, settings, **inputs): from openfe.protocols.openmm_rfe import _rfe_utils from feflow.utils.hybrid_topology import HybridTopologyFactory + if extends_data := self.inputs.get("extends_data"): + + def _write_xml(data, filename): + openmm_object = decompress_and_deserialize(data) + serialize(openmm_object, filename) + return filename + + for cycle in range(settings.num_cycles): + cycle = str(cycle) + system_outfile = ctx.shared / f"system_{cycle}.xml.bz2" + state_outfile = ctx.shared / f"state_{cycle}.xml.bz2" + integrator_outfile = ctx.shared / f"integrator_{cycle}.xml.bz2" + + extends_data["systems"][cycle] = _write_xml( + extends_data["systems"][cycle], + system_outfile, + ) + extends_data["states"][cycle] = _write_xml( + extends_data["states"][cycle], + state_outfile, + ) + extends_data["integrators"][cycle] = _write_xml( + extends_data["integrators"][cycle], + integrator_outfile, + ) + + return extends_data + # Check compatibility between states (same receptor and solvent) self._check_states_compatibility(state_a, state_b) @@ -357,10 +390,18 @@ def _execute(self, ctx, *, state_a, state_b, mapping, settings, **inputs): # Explicit cleanup for GPU resources del context, integrator + systems = {} + states = {} + integrators = {} + for cycle_name in map(str, range(settings.num_cycles)): + systems[cycle_name] = system_outfile + states[cycle_name] = state_outfile + integrators[cycle_name] = integrator_outfile + return { - "system": system_outfile, - "state": state_outfile, - "integrator": integrator_outfile, + "systems": systems, + "states": states, + "integrators": integrators, "phase": phase, "initial_atom_indices": hybrid_factory.initial_atom_indices, "final_atom_indices": hybrid_factory.final_atom_indices, @@ -449,9 +490,9 @@ def _execute(self, ctx, *, setup, settings, **inputs): file_logger.addHandler(file_handler) # Get state, system, and integrator from setup unit - system = deserialize(setup.outputs["system"]) - state = deserialize(setup.outputs["state"]) - integrator = deserialize(setup.outputs["integrator"]) + system = deserialize(setup.outputs["systems"][self.name]) + state = deserialize(setup.outputs["states"][self.name]) + integrator = deserialize(setup.outputs["integrators"][self.name]) PeriodicNonequilibriumIntegrator.restore_interface(integrator) # Get atom indices for either end of the hybrid topology @@ -712,7 +753,20 @@ def _execute(self, ctx, *, setup, settings, **inputs): "reverse_neq_final": reverse_neq_new_path, } finally: + compressed_state = serialize_and_compress( + context.getState(getPositions=True), + ) + + compressed_system = serialize_and_compress( + context.getSystem(), + ) + + compressed_integrator = serialize_and_compress( + context.getIntegrator(), + ) + # Explicit cleanup for GPU resources + del context, integrator return { @@ -721,6 +775,9 @@ def _execute(self, ctx, *, setup, settings, **inputs): "trajectory_paths": trajectory_paths, "log": output_log_path, "timing_info": timing_info, + "system": compressed_system, + "state": compressed_state, + "integrator": compressed_integrator, } @@ -919,8 +976,67 @@ def _create( # Handle parameters if mapping is None: raise ValueError("`mapping` is required for this Protocol") - if extends: - raise NotImplementedError("Can't extend simulations yet") + + extends_data = {} + if isinstance(extends, ProtocolDAGResult): + + if not extends.ok(): + raise ValueError("Cannot extend protocols that failed") + + setup = extends.protocol_units[0] + simulations = extends.protocol_units[1:-1] + + r_setup = extends.protocol_unit_results[0] + r_simulations = extends.protocol_unit_results[1:-1] + + # confirm consistency + original_state_a = setup.inputs["state_a"].key + original_state_b = setup.inputs["state_b"].key + original_mapping = setup.inputs["mapping"] + + if original_state_a != stateA.key: + raise ValueError( + "'stateA' key is not the same as the key provided by the 'extends' ProtocolDAGResult." + ) + + if original_state_b != stateB.key: + raise ValueError( + "'stateB' key is not the same as the key provided by the 'extends' ProtocolDAGResult." + ) + + if mapping is not None: + if original_mapping != mapping: + raise ValueError( + "'mapping' is not consistent with the mapping provided by the 'extends' ProtocolDAGResult." + ) + + # TODO: are there instances where this is too strict? + if setup.inputs["settings"] != self.settings: + raise ValueError( + "protocol settings are not consistent with those present in the SetupUnit of the 'extends' ProtocolDAGResult." + ) + + else: + mapping = original_mapping + + systems = {} + states = {} + integrators = {} + + for r_simulation, simulation in zip(r_simulations, simulations): + sim_name = simulation.name + systems[sim_name] = r_simulation.outputs["system"] + states[sim_name] = r_simulation.outputs["state"] + integrators[sim_name] = r_simulation.outputs["integrator"] + + extends_data = dict( + systems=systems, + states=states, + integrators=integrators, + phase=r_setup.outputs["phase"], + initial_atom_indices=r_setup.outputs["initial_atom_indices"], + final_atom_indices=r_setup.outputs["final_atom_indices"], + ) # inputs to `ProtocolUnit.__init__` should either be `Gufe` objects # or JSON-serializable objects @@ -932,13 +1048,12 @@ def _create( mapping=mapping, settings=self.settings, name="setup", + extends_data=extends_data, ) simulations = [ - self._simulation_unit( - setup=setup, settings=self.settings, name=f"{replicate}" - ) - for replicate in range(num_cycles) + self._simulation_unit(setup=setup, settings=self.settings, name=f"{cycle}") + for cycle in range(num_cycles) ] end = ResultUnit(name="result", simulations=simulations) diff --git a/feflow/tests/conftest.py b/feflow/tests/conftest.py index 9cdba24..8fd3e1e 100644 --- a/feflow/tests/conftest.py +++ b/feflow/tests/conftest.py @@ -76,6 +76,7 @@ def short_settings(): settings.thermo_settings.temperature = 300 * unit.kelvin settings.integrator_settings.equilibrium_steps = 250 settings.integrator_settings.nonequilibrium_steps = 250 + settings.num_cycles = 1 settings.work_save_frequency = 50 settings.traj_save_frequency = 250 diff --git a/feflow/tests/test_nonequilibrium_cycling.py b/feflow/tests/test_nonequilibrium_cycling.py index 22da6b1..d1b301b 100644 --- a/feflow/tests/test_nonequilibrium_cycling.py +++ b/feflow/tests/test_nonequilibrium_cycling.py @@ -100,6 +100,92 @@ def test_terminal_units(self, protocol_dag_result): assert isinstance(finals[0], ProtocolUnitResult) assert finals[0].name == "result" + @pytest.mark.parametrize( + "protocol", + [ + "protocol_short", + "protocol_short_multiple_cycles", + ], + ) + def test_pdr_extend( + self, + protocol, + benzene_vacuum_system, + toluene_vacuum_system, + mapping_benzene_toluene, + tmpdir, + request, + ): + protocol = request.getfixturevalue(protocol) + dag = protocol.create( + stateA=benzene_vacuum_system, + stateB=toluene_vacuum_system, + name="Short vacuum transformation", + mapping=mapping_benzene_toluene, + ) + + with tmpdir.as_cwd(): + + base_path = Path("original") + + shared = base_path / "shared" + shared.mkdir(parents=True) + + scratch = base_path / "scratch" + scratch.mkdir(parents=True) + + pdr: ProtocolDAGResult = execute_DAG( + dag, shared_basedir=shared, scratch_basedir=scratch + ) + + setup = pdr.protocol_units[0] + r_setup = pdr.protocol_unit_results[0] + + assert setup.inputs["extends_data"] == {} + + end_states = {} + for simulation, r_simulation in zip( + pdr.protocol_units[1:-1], pdr.protocol_unit_results[1:-1] + ): + assert isinstance(r_simulation.outputs["system"], str) + assert isinstance(r_simulation.outputs["state"], str) + assert isinstance(r_simulation.outputs["integrator"], str) + + end_states[simulation.name] = r_simulation.outputs["state"] + + dag = protocol.create( + stateA=benzene_vacuum_system, + stateB=toluene_vacuum_system, + name="Short vacuum transformation, but extended", + mapping=mapping_benzene_toluene, + extends=ProtocolDAGResult.from_dict(pdr.to_dict()), + ) + + with tmpdir.as_cwd(): + + base_path = Path("extended") + + shared = base_path / "shared" + shared.mkdir(parents=True) + + scratch = base_path / "scratch" + scratch.mkdir(parents=True) + pdr: ProtocolDAGResult = execute_DAG( + dag, shared_basedir=shared, scratch_basedir=scratch + ) + + r_setup = pdr.protocol_unit_results[0] + + assert r_setup.inputs["extends_data"] != {} + + for cycle in range(protocol.settings.num_cycles): + cycle = str(cycle) + assert isinstance(r_setup.inputs["extends_data"]["systems"][cycle], str) + assert isinstance(r_setup.inputs["extends_data"]["states"][cycle], str) + assert isinstance(r_setup.inputs["extends_data"]["integrators"][cycle], str) + + assert r_setup.inputs["extends_data"]["states"][cycle] == end_states[cycle] + # TODO: We probably need to find failure test cases as control # def test_dag_execute_failure(self, protocol_dag_broken): # protocol, dag, dagfailure = protocol_dag_broken diff --git a/feflow/utils/data.py b/feflow/utils/data.py index f829346..64665d0 100644 --- a/feflow/utils/data.py +++ b/feflow/utils/data.py @@ -1,5 +1,48 @@ import os import pathlib +import bz2 +import base64 + +from openmm import XmlSerializer + + +def serialize_and_compress(item) -> str: + """Serialize an OpenMM System, State, or Integrator and compress. + + Parameters + ---------- + item : System, State, or Integrator + The OpenMM object to serialize and compress. + + Returns + ------- + b64string : str + The compressed serialized OpenMM object encoded in a Base64 string. + """ + serialized = XmlSerializer.serialize(item).encode() + compressed = bz2.compress(serialized) + b64string = base64.b64encode(compressed).decode("ascii") + return b64string + + +def decompress_and_deserialize(data: str): + """Recover an OpenMM object from compression. + + Parameters + ---------- + data : str + String containing a Base64 encoded bzip2 compressed XML serialization + of an OpenMM object. + + Returns + ------- + deserialized + The deserialized OpenMM object. + """ + compressed = base64.b64decode(data) + decompressed = bz2.decompress(compressed).decode("utf-8") + deserialized = XmlSerializer.deserialize(decompressed) + return deserialized def serialize(item, filename: pathlib.Path): @@ -13,7 +56,6 @@ def serialize(item, filename: pathlib.Path): filename : str The filename to serialize to """ - from openmm import XmlSerializer # Create parent directory if it doesn't exist filename_basedir = filename.parent