From 0ba93459a4982f1b6f7be39f15feccc9791a9361 Mon Sep 17 00:00:00 2001 From: richard gowers Date: Mon, 18 Dec 2023 18:01:40 +0000 Subject: [PATCH 1/8] consequences of gufe #260 change all mappings to Protocol.create/Transformation.__init__ are now lists of mappings, not a dict --- .../openmm_afe/equil_solvation_afe_method.py | 2 +- .../protocols/openmm_rfe/equil_rfe_methods.py | 32 ++++++---- .../relative_alchemical_network_planner.py | 2 +- .../test_openmm_equil_rfe_protocols.py | 64 +++++++++---------- .../tests/protocols/test_openmm_rfe_slow.py | 4 +- .../tests/protocols/test_rfe_tokenization.py | 2 +- 6 files changed, 55 insertions(+), 51 deletions(-) diff --git a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py index 6de7daa1d..66d943e09 100644 --- a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py +++ b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py @@ -521,7 +521,7 @@ def _create( self, stateA: ChemicalSystem, stateB: ChemicalSystem, - mapping: Optional[Dict[str, gufe.ComponentMapping]] = None, + mapping: Optional[list[gufe.ComponentMapping]] = None, extends: Optional[gufe.ProtocolDAGResult] = None, ) -> list[gufe.ProtocolUnit]: # TODO: extensions diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index 1cc92e8ce..84c35ba4a 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -173,8 +173,8 @@ def _get_alchemical_charge_difference( def _validate_alchemical_components( alchemical_components: dict[str, list[Component]], - mapping: Optional[dict[str, ComponentMapping]], -): + mapping: Optional[list[ComponentMapping]], +) -> LigandAtomMapping: """ Checks that the alchemical components are suitable for the RFE protocol. @@ -188,9 +188,14 @@ def _validate_alchemical_components( alchemical_components : dict[str, list[Component]] Dictionary contatining the alchemical components for states A and B. - mapping : dict[str, ComponentMapping] + mapping : Optional[list[ComponentMapping]] Dictionary of mappings between transforming components. + Returns + ------- + mapping : LigandAtomMapping + if all the above checks pass, returns the single mapping + Raises ------ ValueError @@ -203,14 +208,13 @@ def _validate_alchemical_components( """ # Check mapping # For now we only allow for a single mapping, this will likely change - if mapping is None or len(mapping.values()) > 1: + if mapping is None or len(mapping) > 1: errmsg = "A single LigandAtomMapping is expected for this Protocol" raise ValueError(errmsg) # Check that all alchemical components are mapped & small molecules - mapped = {} - mapped['stateA'] = [m.componentA for m in mapping.values()] - mapped['stateB'] = [m.componentB for m in mapping.values()] + mapped = {'stateA': [m.componentA for m in mapping], + 'stateB': [m.componentB for m in mapping]} for idx in ['stateA', 'stateB']: if len(alchemical_components[idx]) != len(mapped[idx]): @@ -226,7 +230,7 @@ def _validate_alchemical_components( raise ValueError(errmsg) # Validate element changes in mappings - for m in mapping.values(): + for m in mapping: molA = m.componentA.to_rdkit() molB = m.componentB.to_rdkit() for i, j in m.componentA_to_componentB.items(): @@ -243,6 +247,8 @@ def _validate_alchemical_components( logger.warning(wmsg) warnings.warn(wmsg) # TODO: remove this once logging is fixed + return mapping[0] # type: ignore + class RelativeHybridTopologyProtocolResult(gufe.ProtocolResult): """Dict-like container for the output of a RelativeHybridTopologyProtocol""" @@ -469,7 +475,7 @@ def _create( self, stateA: ChemicalSystem, stateB: ChemicalSystem, - mapping: Optional[dict[str, gufe.ComponentMapping]] = None, + mapping: Optional[list[gufe.ComponentMapping]] = None, extends: Optional[gufe.ProtocolDAGResult] = None, ) -> list[gufe.ProtocolUnit]: # TODO: Extensions? @@ -480,10 +486,7 @@ def _create( alchem_comps = system_validation.get_alchemical_components( stateA, stateB ) - _validate_alchemical_components(alchem_comps, mapping) - - # For now we've made it fail already if it was None, - ligandmapping = list(mapping.values())[0] # type: ignore + ligandmapping = _validate_alchemical_components(alchem_comps, mapping) # Validate solvent component nonbond = self.settings.system_settings.nonbonded_method @@ -498,7 +501,8 @@ def _create( # our DAG has no dependencies, so just list units n_repeats = self.settings.alchemical_sampler_settings.n_repeats units = [RelativeHybridTopologyProtocolUnit( - stateA=stateA, stateB=stateB, ligandmapping=ligandmapping, + stateA=stateA, stateB=stateB, + ligandmapping=ligandmapping, settings=self.settings, generation=0, repeat_id=int(uuid.uuid4()), name=f'{Anames} to {Bnames} repeat {i} generation 0') diff --git a/openfe/setup/alchemical_network_planner/relative_alchemical_network_planner.py b/openfe/setup/alchemical_network_planner/relative_alchemical_network_planner.py index 3b34aa511..f66e4f7fc 100644 --- a/openfe/setup/alchemical_network_planner/relative_alchemical_network_planner.py +++ b/openfe/setup/alchemical_network_planner/relative_alchemical_network_planner.py @@ -218,7 +218,7 @@ def _build_transformation( return Transformation( stateA=stateA, stateB=stateB, - mapping={RFEComponentLabels.LIGAND: ligand_mapping_edge}, + mapping=[ligand_mapping_edge], name=transformation_name, protocol=transformation_protocol, ) diff --git a/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py b/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py index 509c38e84..17f127b3a 100644 --- a/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py +++ b/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py @@ -137,12 +137,12 @@ def test_create_independent_repeat_ids(benzene_system, toluene_system, benzene_t dag1 = protocol.create( stateA=benzene_system, stateB=toluene_system, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=[benzene_to_toluene_mapping], ) dag2 = protocol.create( stateA=benzene_system, stateB=toluene_system, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=[benzene_to_toluene_mapping], ) repeat_ids = set() @@ -156,7 +156,7 @@ def test_create_independent_repeat_ids(benzene_system, toluene_system, benzene_t @pytest.mark.parametrize('mapping', [ - None, {'A': 'Foo', 'B': 'bar'}, + None, ['A', 'B'], ]) def test_validate_alchemical_components_wrong_mappings(mapping): with pytest.raises(ValueError, match="A single LigandAtomMapping"): @@ -170,7 +170,7 @@ def test_validate_alchemical_components_missing_alchem_comp( alchem_comps = {'stateA': [openfe.SolventComponent(),], 'stateB': []} with pytest.raises(ValueError, match="Unmapped alchemical component"): _validate_alchemical_components( - alchem_comps, {'ligand': benzene_to_toluene_mapping}, + alchem_comps, [benzene_to_toluene_mapping], ) @@ -193,7 +193,7 @@ def test_dry_run_default_vacuum(benzene_vacuum_system, toluene_vacuum_system, dag = protocol.create( stateA=benzene_vacuum_system, stateB=toluene_vacuum_system, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=[benzene_to_toluene_mapping], ) dag_unit = list(dag.protocol_units)[0] @@ -237,7 +237,7 @@ def test_dry_run_gaff_vacuum(benzene_vacuum_system, toluene_vacuum_system, dag = protocol.create( stateA=benzene_vacuum_system, stateB=toluene_vacuum_system, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=[benzene_to_toluene_mapping], ) unit = list(dag.protocol_units)[0] @@ -263,7 +263,7 @@ def test_dry_many_molecules_solvent( dag = protocol.create( stateA=benzene_many_solv_system, stateB=toluene_many_solv_system, - mapping={'spicyligand': benzene_to_toluene_mapping}, + mapping=[benzene_to_toluene_mapping], ) unit = list(dag.protocol_units)[0] @@ -357,7 +357,7 @@ def test_dry_core_element_change(tmpdir): dag = protocol.create( stateA=openfe.ChemicalSystem({'ligand': benz,}), stateB=openfe.ChemicalSystem({'ligand': pyr,}), - mapping={'whatamapping': mapping}, + mapping=[mapping], ) dag_unit = list(dag.protocol_units)[0] @@ -393,7 +393,7 @@ def test_dry_run_ligand(benzene_system, toluene_system, dag = protocol.create( stateA=benzene_system, stateB=toluene_system, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=[benzene_to_toluene_mapping], ) dag_unit = list(dag.protocol_units)[0] @@ -421,7 +421,7 @@ def test_confgen_mocked_fail(benzene_system, toluene_system, protocol = openmm_rfe.RelativeHybridTopologyProtocol(settings=settings) dag = protocol.create(stateA=benzene_system, stateB=toluene_system, - mapping={'ligand': benzene_to_toluene_mapping}) + mapping=[benzene_to_toluene_mapping]) dag_unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): @@ -456,7 +456,7 @@ def tip4p_hybrid_factory( dag = protocol.create( stateA=benzene_system, stateB=toluene_system, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=[benzene_to_toluene_mapping], ) dag_unit = list(dag.protocol_units)[0] @@ -643,7 +643,7 @@ def check_propchgs(smc, charge_array): dag = protocol.create( stateA=openfe.ChemicalSystem({'l': benzene_smc,}), stateB=openfe.ChemicalSystem({'l': toluene_smc,}), - mapping={'ligand': mapping}, + mapping=[mapping], ) dag_unit = list(dag.protocol_units)[0] @@ -737,7 +737,7 @@ def test_virtual_sites_no_reassign(benzene_system, toluene_system, dag = protocol.create( stateA=benzene_system, stateB=toluene_system, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=[benzene_to_toluene_mapping], ) dag_unit = list(dag.protocol_units)[0] @@ -763,7 +763,7 @@ def test_dry_run_complex(benzene_complex_system, toluene_complex_system, dag = protocol.create( stateA=benzene_complex_system, stateB=toluene_complex_system, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=[benzene_to_toluene_mapping], ) dag_unit = list(dag.protocol_units)[0] @@ -806,7 +806,7 @@ def test_hightimestep(benzene_vacuum_system, dag = p.create( stateA=benzene_vacuum_system, stateB=toluene_vacuum_system, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=[benzene_to_toluene_mapping], ) dag_unit = list(dag.protocol_units)[0] @@ -837,7 +837,7 @@ def test_n_replicas_not_n_windows(benzene_vacuum_system, dag = p.create( stateA=benzene_vacuum_system, stateB=toluene_vacuum_system, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=[benzene_to_toluene_mapping], ) dag_unit = list(dag.protocol_units)[0] dag_unit.run(dry=True) @@ -856,7 +856,7 @@ def test_missing_ligand(benzene_system, benzene_to_toluene_mapping): _ = p.create( stateA=benzene_system, stateB=stateB, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=[benzene_to_toluene_mapping], ) @@ -873,7 +873,7 @@ def test_vaccuum_PME_error(benzene_vacuum_system, benzene_modifications, _ = p.create( stateA=benzene_vacuum_system, stateB=stateB, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=[benzene_to_toluene_mapping], ) @@ -896,7 +896,7 @@ def test_incompatible_solvent(benzene_system, benzene_modifications, _ = p.create( stateA=benzene_system, stateB=stateB, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=[benzene_to_toluene_mapping], ) @@ -917,7 +917,7 @@ def test_mapping_mismatch_A(benzene_system, toluene_system, _ = p.create( stateA=benzene_system, stateB=toluene_system, - mapping={'ligand': mapping}, + mapping=[mapping], ) @@ -937,7 +937,7 @@ def test_mapping_mismatch_B(benzene_system, toluene_system, _ = p.create( stateA=benzene_system, stateB=toluene_system, - mapping={'ligand': mapping}, + mapping=[mapping], ) @@ -951,12 +951,12 @@ def test_complex_mismatch(benzene_system, toluene_complex_system, _ = p.create( stateA=benzene_system, stateB=toluene_complex_system, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=[benzene_to_toluene_mapping], ) def test_too_many_specified_mappings(benzene_system, toluene_system, - benzene_to_toluene_mapping): + benzene_to_toluene_mapping): # mapping dict requires 'ligand' key p = openmm_rfe.RelativeHybridTopologyProtocol( settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), @@ -966,8 +966,8 @@ def test_too_many_specified_mappings(benzene_system, toluene_system, _ = p.create( stateA=benzene_system, stateB=toluene_system, - mapping={'solvent': benzene_to_toluene_mapping, - 'ligand': benzene_to_toluene_mapping,} + mapping=[benzene_to_toluene_mapping, + benzene_to_toluene_mapping], ) @@ -990,7 +990,7 @@ def test_protein_mismatch(benzene_complex_system, toluene_complex_system, _ = p.create( stateA=benzene_complex_system, stateB=alt_toluene_complex_system, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=[benzene_to_toluene_mapping], ) @@ -1015,7 +1015,7 @@ def test_element_change_warning(atom_mapping_basic_test_files): with pytest.warns(UserWarning, match="Element change"): _ = p.create( stateA=sys1, stateB=sys2, - mapping={'ligand': mapping}, + mapping=[mapping], ) @@ -1052,7 +1052,7 @@ def test_ligand_overlap_warning(benzene_vacuum_system, toluene_vacuum_system, with pytest.warns(UserWarning, match='0 : 4 deviates'): dag = protocol.create( stateA=sysA, stateB=toluene_vacuum_system, - mapping={'ligand': mapping}, + mapping=[mapping], ) dag_unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): @@ -1069,7 +1069,7 @@ def solvent_protocol_dag(benzene_system, toluene_system, benzene_to_toluene_mapp return protocol.create( stateA=benzene_system, stateB=toluene_system, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=[benzene_to_toluene_mapping], ) @@ -1343,7 +1343,7 @@ def tyk2_xml(tmp_path_factory): dag = protocol.create( stateA=openfe.ChemicalSystem({'ligand': lig23}), stateB=openfe.ChemicalSystem({'ligand': lig55}), - mapping={'ligand': mapping}, + mapping=[mapping], ) pu = list(dag.protocol_units)[0] @@ -1888,7 +1888,7 @@ def test_dry_run_alchemwater_solvent(benzene_to_benzoic_mapping, tmpdir): dag = protocol.create( stateA=stateA_system, stateB=stateB_system, - mapping={'ligand': benzene_to_benzoic_mapping}, + mapping=[benzene_to_benzoic_mapping], ) unit = list(dag.protocol_units)[0] @@ -1941,7 +1941,7 @@ def test_dry_run_complex_alchemwater_totcharge( dag = protocol.create( stateA=stateA_system, stateB=stateB_system, - mapping={'ligand': mapping}, + mapping=[mapping], ) unit = list(dag.protocol_units)[0] diff --git a/openfe/tests/protocols/test_openmm_rfe_slow.py b/openfe/tests/protocols/test_openmm_rfe_slow.py index 2af72929c..fe0a07250 100644 --- a/openfe/tests/protocols/test_openmm_rfe_slow.py +++ b/openfe/tests/protocols/test_openmm_rfe_slow.py @@ -67,7 +67,7 @@ def test_openmm_run_engine(benzene_vacuum_system, platform, m = openfe.LigandAtomMapping(componentA=b, componentB=b_alt, componentA_to_componentB={i: i for i in range(12)}) dag = p.create(stateA=benzene_vacuum_system, stateB=benzene_vacuum_alt_system, - mapping={'ligand': m}) + mapping=[m]) cwd = pathlib.Path(str(tmpdir)) r = execute_DAG(dag, shared_basedir=cwd, scratch_basedir=cwd, @@ -128,7 +128,7 @@ def test_run_eg5_sim(eg5_protein, eg5_ligands, eg5_cofactor, tmpdir): sys2 = openfe.ChemicalSystem(components={**base_sys, 'ligand': l2}) dag = p.create(stateA=sys1, stateB=sys2, - mapping={'ligand': m}) + mapping=[m]) cwd = pathlib.Path(str(tmpdir)) r = execute_DAG(dag, shared_basedir=cwd, scratch_basedir=cwd, diff --git a/openfe/tests/protocols/test_rfe_tokenization.py b/openfe/tests/protocols/test_rfe_tokenization.py index 2ff65982e..7d18a4ca1 100644 --- a/openfe/tests/protocols/test_rfe_tokenization.py +++ b/openfe/tests/protocols/test_rfe_tokenization.py @@ -21,7 +21,7 @@ def protocol(): def protocol_unit(protocol, benzene_system, toluene_system, benzene_to_toluene_mapping): pus = protocol.create( stateA=benzene_system, stateB=toluene_system, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=[benzene_to_toluene_mapping], ) return list(pus.protocol_units)[0] From 57fc2fe2b0f4ff051b77e367eb703f937eb4654a Mon Sep 17 00:00:00 2001 From: richard gowers Date: Tue, 16 Jan 2024 14:49:11 +0000 Subject: [PATCH 2/8] final fixups for new _create/create gufe API --- .../openmm_afe/equil_solvation_afe_method.py | 2 +- .../protocols/openmm_rfe/equil_rfe_methods.py | 26 +++--- .../relative_alchemical_network_planner.py | 2 +- .../test_openmm_equil_rfe_protocols.py | 84 +++++++++---------- 4 files changed, 52 insertions(+), 62 deletions(-) diff --git a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py index 66d943e09..61c9f6058 100644 --- a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py +++ b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py @@ -521,7 +521,7 @@ def _create( self, stateA: ChemicalSystem, stateB: ChemicalSystem, - mapping: Optional[list[gufe.ComponentMapping]] = None, + mapping: list[gufe.ComponentMapping], extends: Optional[gufe.ProtocolDAGResult] = None, ) -> list[gufe.ProtocolUnit]: # TODO: extensions diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index 84c35ba4a..0627da047 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -173,8 +173,8 @@ def _get_alchemical_charge_difference( def _validate_alchemical_components( alchemical_components: dict[str, list[Component]], - mapping: Optional[list[ComponentMapping]], -) -> LigandAtomMapping: + mapping: list[ComponentMapping], +): """ Checks that the alchemical components are suitable for the RFE protocol. @@ -188,13 +188,8 @@ def _validate_alchemical_components( alchemical_components : dict[str, list[Component]] Dictionary contatining the alchemical components for states A and B. - mapping : Optional[list[ComponentMapping]] - Dictionary of mappings between transforming components. - - Returns - ------- - mapping : LigandAtomMapping - if all the above checks pass, returns the single mapping + mapping : list[ComponentMapping] + all mappings between transforming components. Raises ------ @@ -208,7 +203,7 @@ def _validate_alchemical_components( """ # Check mapping # For now we only allow for a single mapping, this will likely change - if mapping is None or len(mapping) > 1: + if len(mapping) != 1: errmsg = "A single LigandAtomMapping is expected for this Protocol" raise ValueError(errmsg) @@ -247,8 +242,6 @@ def _validate_alchemical_components( logger.warning(wmsg) warnings.warn(wmsg) # TODO: remove this once logging is fixed - return mapping[0] # type: ignore - class RelativeHybridTopologyProtocolResult(gufe.ProtocolResult): """Dict-like container for the output of a RelativeHybridTopologyProtocol""" @@ -475,7 +468,7 @@ def _create( self, stateA: ChemicalSystem, stateB: ChemicalSystem, - mapping: Optional[list[gufe.ComponentMapping]] = None, + mapping: list[gufe.ComponentMapping], extends: Optional[gufe.ProtocolDAGResult] = None, ) -> list[gufe.ProtocolUnit]: # TODO: Extensions? @@ -486,7 +479,8 @@ def _create( alchem_comps = system_validation.get_alchemical_components( stateA, stateB ) - ligandmapping = _validate_alchemical_components(alchem_comps, mapping) + _validate_alchemical_components(alchem_comps, mapping) + ligandmapping = mapping[0] # type: ignore # Validate solvent component nonbond = self.settings.system_settings.nonbonded_method @@ -502,8 +496,8 @@ def _create( n_repeats = self.settings.alchemical_sampler_settings.n_repeats units = [RelativeHybridTopologyProtocolUnit( stateA=stateA, stateB=stateB, - ligandmapping=ligandmapping, - settings=self.settings, + ligandmapping=ligandmapping, # type: ignore + settings=self.settings, # type: ignore generation=0, repeat_id=int(uuid.uuid4()), name=f'{Anames} to {Bnames} repeat {i} generation 0') for i in range(n_repeats)] diff --git a/openfe/setup/alchemical_network_planner/relative_alchemical_network_planner.py b/openfe/setup/alchemical_network_planner/relative_alchemical_network_planner.py index f66e4f7fc..bbe22eb32 100644 --- a/openfe/setup/alchemical_network_planner/relative_alchemical_network_planner.py +++ b/openfe/setup/alchemical_network_planner/relative_alchemical_network_planner.py @@ -218,7 +218,7 @@ def _build_transformation( return Transformation( stateA=stateA, stateB=stateB, - mapping=[ligand_mapping_edge], + mapping=ligand_mapping_edge, name=transformation_name, protocol=transformation_protocol, ) diff --git a/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py b/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py index 17f127b3a..c35e9eb16 100644 --- a/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py +++ b/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py @@ -1,40 +1,36 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -import os -from io import StringIO import copy -import numpy as np -import gufe -from gufe.tests.test_tokenization import GufeTokenizableTestsMixin import json -import pytest -from unittest import mock -from openff.units.openmm import to_openmm, from_openmm -from openff.units import unit -from importlib import resources import xml.etree.ElementTree as ET +from importlib import resources +from unittest import mock +import gufe +import mdtraj as mdt +import numpy as np +import pytest +from openff.units import unit +from openff.units.openmm import ensure_quantity +from openff.units.openmm import to_openmm, from_openmm from openmm import ( app, XmlSerializer, MonteCarloBarostat, NonbondedForce, CustomNonbondedForce ) from openmm import unit as omm_unit +from openmmforcefields.generators import SMIRNOFFTemplateGenerator from openmmtools.multistate.multistatesampler import MultiStateSampler -import pathlib from rdkit import Chem from rdkit.Geometry import Point3D -import mdtraj as mdt import openfe from openfe import setup from openfe.protocols import openmm_rfe +from openfe.protocols.openmm_rfe._rfe_utils import topologyhelpers from openfe.protocols.openmm_rfe.equil_rfe_methods import ( - _validate_alchemical_components, _get_alchemical_charge_difference + _validate_alchemical_components, _get_alchemical_charge_difference ) -from openfe.protocols.openmm_rfe._rfe_utils import topologyhelpers from openfe.protocols.openmm_utils import system_creation -from openmmforcefields.generators import SMIRNOFFTemplateGenerator -from openff.units.openmm import ensure_quantity def test_compute_platform_warn(): @@ -137,12 +133,12 @@ def test_create_independent_repeat_ids(benzene_system, toluene_system, benzene_t dag1 = protocol.create( stateA=benzene_system, stateB=toluene_system, - mapping=[benzene_to_toluene_mapping], + mapping=benzene_to_toluene_mapping, ) dag2 = protocol.create( stateA=benzene_system, stateB=toluene_system, - mapping=[benzene_to_toluene_mapping], + mapping=benzene_to_toluene_mapping, ) repeat_ids = set() @@ -156,7 +152,7 @@ def test_create_independent_repeat_ids(benzene_system, toluene_system, benzene_t @pytest.mark.parametrize('mapping', [ - None, ['A', 'B'], + [], ['A', 'B'], ]) def test_validate_alchemical_components_wrong_mappings(mapping): with pytest.raises(ValueError, match="A single LigandAtomMapping"): @@ -193,7 +189,7 @@ def test_dry_run_default_vacuum(benzene_vacuum_system, toluene_vacuum_system, dag = protocol.create( stateA=benzene_vacuum_system, stateB=toluene_vacuum_system, - mapping=[benzene_to_toluene_mapping], + mapping=benzene_to_toluene_mapping, ) dag_unit = list(dag.protocol_units)[0] @@ -237,7 +233,7 @@ def test_dry_run_gaff_vacuum(benzene_vacuum_system, toluene_vacuum_system, dag = protocol.create( stateA=benzene_vacuum_system, stateB=toluene_vacuum_system, - mapping=[benzene_to_toluene_mapping], + mapping=benzene_to_toluene_mapping, ) unit = list(dag.protocol_units)[0] @@ -263,7 +259,7 @@ def test_dry_many_molecules_solvent( dag = protocol.create( stateA=benzene_many_solv_system, stateB=toluene_many_solv_system, - mapping=[benzene_to_toluene_mapping], + mapping=benzene_to_toluene_mapping, ) unit = list(dag.protocol_units)[0] @@ -357,7 +353,7 @@ def test_dry_core_element_change(tmpdir): dag = protocol.create( stateA=openfe.ChemicalSystem({'ligand': benz,}), stateB=openfe.ChemicalSystem({'ligand': pyr,}), - mapping=[mapping], + mapping=mapping, ) dag_unit = list(dag.protocol_units)[0] @@ -393,7 +389,7 @@ def test_dry_run_ligand(benzene_system, toluene_system, dag = protocol.create( stateA=benzene_system, stateB=toluene_system, - mapping=[benzene_to_toluene_mapping], + mapping=benzene_to_toluene_mapping, ) dag_unit = list(dag.protocol_units)[0] @@ -421,7 +417,7 @@ def test_confgen_mocked_fail(benzene_system, toluene_system, protocol = openmm_rfe.RelativeHybridTopologyProtocol(settings=settings) dag = protocol.create(stateA=benzene_system, stateB=toluene_system, - mapping=[benzene_to_toluene_mapping]) + mapping=benzene_to_toluene_mapping) dag_unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): @@ -456,7 +452,7 @@ def tip4p_hybrid_factory( dag = protocol.create( stateA=benzene_system, stateB=toluene_system, - mapping=[benzene_to_toluene_mapping], + mapping=benzene_to_toluene_mapping, ) dag_unit = list(dag.protocol_units)[0] @@ -643,7 +639,7 @@ def check_propchgs(smc, charge_array): dag = protocol.create( stateA=openfe.ChemicalSystem({'l': benzene_smc,}), stateB=openfe.ChemicalSystem({'l': toluene_smc,}), - mapping=[mapping], + mapping=mapping, ) dag_unit = list(dag.protocol_units)[0] @@ -737,7 +733,7 @@ def test_virtual_sites_no_reassign(benzene_system, toluene_system, dag = protocol.create( stateA=benzene_system, stateB=toluene_system, - mapping=[benzene_to_toluene_mapping], + mapping=benzene_to_toluene_mapping, ) dag_unit = list(dag.protocol_units)[0] @@ -763,7 +759,7 @@ def test_dry_run_complex(benzene_complex_system, toluene_complex_system, dag = protocol.create( stateA=benzene_complex_system, stateB=toluene_complex_system, - mapping=[benzene_to_toluene_mapping], + mapping=benzene_to_toluene_mapping, ) dag_unit = list(dag.protocol_units)[0] @@ -806,7 +802,7 @@ def test_hightimestep(benzene_vacuum_system, dag = p.create( stateA=benzene_vacuum_system, stateB=toluene_vacuum_system, - mapping=[benzene_to_toluene_mapping], + mapping=benzene_to_toluene_mapping, ) dag_unit = list(dag.protocol_units)[0] @@ -837,7 +833,7 @@ def test_n_replicas_not_n_windows(benzene_vacuum_system, dag = p.create( stateA=benzene_vacuum_system, stateB=toluene_vacuum_system, - mapping=[benzene_to_toluene_mapping], + mapping=benzene_to_toluene_mapping, ) dag_unit = list(dag.protocol_units)[0] dag_unit.run(dry=True) @@ -856,7 +852,7 @@ def test_missing_ligand(benzene_system, benzene_to_toluene_mapping): _ = p.create( stateA=benzene_system, stateB=stateB, - mapping=[benzene_to_toluene_mapping], + mapping=benzene_to_toluene_mapping, ) @@ -873,7 +869,7 @@ def test_vaccuum_PME_error(benzene_vacuum_system, benzene_modifications, _ = p.create( stateA=benzene_vacuum_system, stateB=stateB, - mapping=[benzene_to_toluene_mapping], + mapping=benzene_to_toluene_mapping, ) @@ -896,7 +892,7 @@ def test_incompatible_solvent(benzene_system, benzene_modifications, _ = p.create( stateA=benzene_system, stateB=stateB, - mapping=[benzene_to_toluene_mapping], + mapping=benzene_to_toluene_mapping, ) @@ -917,7 +913,7 @@ def test_mapping_mismatch_A(benzene_system, toluene_system, _ = p.create( stateA=benzene_system, stateB=toluene_system, - mapping=[mapping], + mapping=mapping, ) @@ -937,7 +933,7 @@ def test_mapping_mismatch_B(benzene_system, toluene_system, _ = p.create( stateA=benzene_system, stateB=toluene_system, - mapping=[mapping], + mapping=mapping, ) @@ -951,7 +947,7 @@ def test_complex_mismatch(benzene_system, toluene_complex_system, _ = p.create( stateA=benzene_system, stateB=toluene_complex_system, - mapping=[benzene_to_toluene_mapping], + mapping=benzene_to_toluene_mapping, ) @@ -990,7 +986,7 @@ def test_protein_mismatch(benzene_complex_system, toluene_complex_system, _ = p.create( stateA=benzene_complex_system, stateB=alt_toluene_complex_system, - mapping=[benzene_to_toluene_mapping], + mapping=benzene_to_toluene_mapping, ) @@ -1015,7 +1011,7 @@ def test_element_change_warning(atom_mapping_basic_test_files): with pytest.warns(UserWarning, match="Element change"): _ = p.create( stateA=sys1, stateB=sys2, - mapping=[mapping], + mapping=mapping, ) @@ -1052,7 +1048,7 @@ def test_ligand_overlap_warning(benzene_vacuum_system, toluene_vacuum_system, with pytest.warns(UserWarning, match='0 : 4 deviates'): dag = protocol.create( stateA=sysA, stateB=toluene_vacuum_system, - mapping=[mapping], + mapping=mapping, ) dag_unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): @@ -1069,7 +1065,7 @@ def solvent_protocol_dag(benzene_system, toluene_system, benzene_to_toluene_mapp return protocol.create( stateA=benzene_system, stateB=toluene_system, - mapping=[benzene_to_toluene_mapping], + mapping=benzene_to_toluene_mapping, ) @@ -1343,7 +1339,7 @@ def tyk2_xml(tmp_path_factory): dag = protocol.create( stateA=openfe.ChemicalSystem({'ligand': lig23}), stateB=openfe.ChemicalSystem({'ligand': lig55}), - mapping=[mapping], + mapping=mapping, ) pu = list(dag.protocol_units)[0] @@ -1888,7 +1884,7 @@ def test_dry_run_alchemwater_solvent(benzene_to_benzoic_mapping, tmpdir): dag = protocol.create( stateA=stateA_system, stateB=stateB_system, - mapping=[benzene_to_benzoic_mapping], + mapping=benzene_to_benzoic_mapping, ) unit = list(dag.protocol_units)[0] @@ -1941,7 +1937,7 @@ def test_dry_run_complex_alchemwater_totcharge( dag = protocol.create( stateA=stateA_system, stateB=stateB_system, - mapping=[mapping], + mapping=mapping, ) unit = list(dag.protocol_units)[0] From e1b1c9c8b2be2053faf32caffcea8ae247449e94 Mon Sep 17 00:00:00 2001 From: richard gowers Date: Fri, 19 Jan 2024 12:26:55 +0000 Subject: [PATCH 3/8] allow None for mappings in _create --- openfe/protocols/openmm_afe/equil_solvation_afe_method.py | 2 +- openfe/protocols/openmm_rfe/equil_rfe_methods.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py index 61c9f6058..80c5a9201 100644 --- a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py +++ b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py @@ -521,7 +521,7 @@ def _create( self, stateA: ChemicalSystem, stateB: ChemicalSystem, - mapping: list[gufe.ComponentMapping], + mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]] = None, extends: Optional[gufe.ProtocolDAGResult] = None, ) -> list[gufe.ProtocolUnit]: # TODO: extensions diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index 0627da047..e502b3f9d 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -173,7 +173,7 @@ def _get_alchemical_charge_difference( def _validate_alchemical_components( alchemical_components: dict[str, list[Component]], - mapping: list[ComponentMapping], + mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]], ): """ Checks that the alchemical components are suitable for the RFE protocol. @@ -188,7 +188,7 @@ def _validate_alchemical_components( alchemical_components : dict[str, list[Component]] Dictionary contatining the alchemical components for states A and B. - mapping : list[ComponentMapping] + mapping : Optional[Union[ComponentMapping, list[ComponentMapping]]] all mappings between transforming components. Raises @@ -203,7 +203,7 @@ def _validate_alchemical_components( """ # Check mapping # For now we only allow for a single mapping, this will likely change - if len(mapping) != 1: + if mapping is None or len(mapping) > 1: errmsg = "A single LigandAtomMapping is expected for this Protocol" raise ValueError(errmsg) @@ -468,7 +468,7 @@ def _create( self, stateA: ChemicalSystem, stateB: ChemicalSystem, - mapping: list[gufe.ComponentMapping], + mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]], extends: Optional[gufe.ProtocolDAGResult] = None, ) -> list[gufe.ProtocolUnit]: # TODO: Extensions? From a0651bf7daff401654d3431f01bc8ea5095bfd33 Mon Sep 17 00:00:00 2001 From: richard gowers Date: Fri, 19 Jan 2024 12:30:08 +0000 Subject: [PATCH 4/8] allow None for mappings in _create --- openfe/tests/protocols/test_openmm_equil_rfe_protocols.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py b/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py index c35e9eb16..4374a5455 100644 --- a/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py +++ b/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py @@ -152,7 +152,7 @@ def test_create_independent_repeat_ids(benzene_system, toluene_system, benzene_t @pytest.mark.parametrize('mapping', [ - [], ['A', 'B'], + None, [], ['A', 'B'], ]) def test_validate_alchemical_components_wrong_mappings(mapping): with pytest.raises(ValueError, match="A single LigandAtomMapping"): @@ -166,7 +166,7 @@ def test_validate_alchemical_components_missing_alchem_comp( alchem_comps = {'stateA': [openfe.SolventComponent(),], 'stateB': []} with pytest.raises(ValueError, match="Unmapped alchemical component"): _validate_alchemical_components( - alchem_comps, [benzene_to_toluene_mapping], + alchem_comps, benzene_to_toluene_mapping, ) From 14d4fe6b449c89b4fd3dc149617a9df46044e25e Mon Sep 17 00:00:00 2001 From: richard gowers Date: Wed, 24 Jan 2024 09:47:10 +0000 Subject: [PATCH 5/8] fix case where mapping is a single ComponentMapping --- openfe/protocols/openmm_rfe/equil_rfe_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index e502b3f9d..ddd5cab28 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -203,7 +203,7 @@ def _validate_alchemical_components( """ # Check mapping # For now we only allow for a single mapping, this will likely change - if mapping is None or len(mapping) > 1: + if mapping is None or (isinstance(mapping, list) and len(mapping) > 1): errmsg = "A single LigandAtomMapping is expected for this Protocol" raise ValueError(errmsg) From b224d7146df41e73b76e6dbad295d29af2ddd664 Mon Sep 17 00:00:00 2001 From: richard gowers Date: Wed, 24 Jan 2024 09:51:23 +0000 Subject: [PATCH 6/8] what if someone used a tuple of mappings? Just be safe and check against ComponentMapping --- openfe/protocols/openmm_rfe/equil_rfe_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index ddd5cab28..2ecf6bb04 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -203,7 +203,7 @@ def _validate_alchemical_components( """ # Check mapping # For now we only allow for a single mapping, this will likely change - if mapping is None or (isinstance(mapping, list) and len(mapping) > 1): + if mapping is None or (not isinstance(mapping, ComponentMapping) and len(mapping) > 1): errmsg = "A single LigandAtomMapping is expected for this Protocol" raise ValueError(errmsg) From 7a89fb28d88bb0ef42272f479fbe272b7b860341 Mon Sep 17 00:00:00 2001 From: richard gowers Date: Wed, 24 Jan 2024 09:52:42 +0000 Subject: [PATCH 7/8] what if someone used a tuple of mappings? Just be safe and check against ComponentMapping also guard against len 0 list case --- openfe/protocols/openmm_rfe/equil_rfe_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index 2ecf6bb04..9b2cdb284 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -203,7 +203,7 @@ def _validate_alchemical_components( """ # Check mapping # For now we only allow for a single mapping, this will likely change - if mapping is None or (not isinstance(mapping, ComponentMapping) and len(mapping) > 1): + if mapping is None or (not isinstance(mapping, ComponentMapping) and len(mapping) != 1): errmsg = "A single LigandAtomMapping is expected for this Protocol" raise ValueError(errmsg) From 1c9ee3babff9987a3e8aa1532216321c7aa82fd9 Mon Sep 17 00:00:00 2001 From: richard gowers Date: Wed, 7 Feb 2024 10:34:34 +0000 Subject: [PATCH 8/8] fixups following merge handle when mapping is single ComponentMapping not list --- openfe/protocols/openmm_rfe/equil_rfe_methods.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index 4dec79d48..8fdcb5d00 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -201,9 +201,11 @@ def _validate_alchemical_components( UserWarning * Mappings which involve element changes in core atoms """ + if isinstance(mapping, ComponentMapping): + mapping = [mapping] # Check mapping # For now we only allow for a single mapping, this will likely change - if mapping is None or (not isinstance(mapping, ComponentMapping) and len(mapping) != 1): + if mapping is None or len(mapping) != 1: errmsg = "A single LigandAtomMapping is expected for this Protocol" raise ValueError(errmsg) @@ -481,7 +483,7 @@ def _create( stateA, stateB ) _validate_alchemical_components(alchem_comps, mapping) - ligandmapping = mapping[0] # type: ignore + ligandmapping = mapping[0] if isinstance(mapping, list) else mapping # type: ignore # Validate solvent component nonbond = self.settings.forcefield_settings.nonbonded_method