diff --git a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py index f482592a4..38bec07da 100644 --- a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py +++ b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py @@ -602,7 +602,7 @@ def _create( self, stateA: ChemicalSystem, stateB: ChemicalSystem, - mapping: Optional[dict[str, gufe.ComponentMapping]] = None, + mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]] = None, extends: Optional[gufe.ProtocolDAGResult] = None, ) -> list[gufe.ProtocolUnit]: # TODO: extensions diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index e1c7bb3eb..8fdcb5d00 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -173,7 +173,7 @@ def _get_alchemical_charge_difference( def _validate_alchemical_components( alchemical_components: dict[str, list[Component]], - mapping: Optional[dict[str, ComponentMapping]], + mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]], ): """ Checks that the alchemical components are suitable for the RFE protocol. @@ -188,8 +188,8 @@ def _validate_alchemical_components( alchemical_components : dict[str, list[Component]] Dictionary contatining the alchemical components for states A and B. - mapping : dict[str, ComponentMapping] - Dictionary of mappings between transforming components. + mapping : Optional[Union[ComponentMapping, list[ComponentMapping]]] + all mappings between transforming components. Raises ------ @@ -201,16 +201,17 @@ def _validate_alchemical_components( UserWarning * Mappings which involve element changes in core atoms """ + if isinstance(mapping, ComponentMapping): + mapping = [mapping] # Check mapping # For now we only allow for a single mapping, this will likely change - if mapping is None or len(mapping.values()) > 1: + if mapping is None or len(mapping) != 1: errmsg = "A single LigandAtomMapping is expected for this Protocol" raise ValueError(errmsg) # Check that all alchemical components are mapped & small molecules - mapped = {} - mapped['stateA'] = [m.componentA for m in mapping.values()] - mapped['stateB'] = [m.componentB for m in mapping.values()] + mapped = {'stateA': [m.componentA for m in mapping], + 'stateB': [m.componentB for m in mapping]} for idx in ['stateA', 'stateB']: if len(alchemical_components[idx]) != len(mapped[idx]): @@ -226,7 +227,7 @@ def _validate_alchemical_components( raise ValueError(errmsg) # Validate element changes in mappings - for m in mapping.values(): + for m in mapping: molA = m.componentA.to_rdkit() molB = m.componentB.to_rdkit() for i, j in m.componentA_to_componentB.items(): @@ -470,7 +471,7 @@ def _create( self, stateA: ChemicalSystem, stateB: ChemicalSystem, - mapping: Optional[dict[str, gufe.ComponentMapping]] = None, + mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]], extends: Optional[gufe.ProtocolDAGResult] = None, ) -> list[gufe.ProtocolUnit]: # TODO: Extensions? @@ -482,9 +483,7 @@ def _create( stateA, stateB ) _validate_alchemical_components(alchem_comps, mapping) - - # For now we've made it fail already if it was None, - ligandmapping = list(mapping.values())[0] # type: ignore + ligandmapping = mapping[0] if isinstance(mapping, list) else mapping # type: ignore # Validate solvent component nonbond = self.settings.forcefield_settings.nonbonded_method @@ -500,7 +499,8 @@ def _create( n_repeats = self.settings.protocol_repeats units = [RelativeHybridTopologyProtocolUnit( protocol=self, - stateA=stateA, stateB=stateB, ligandmapping=ligandmapping, + stateA=stateA, stateB=stateB, + ligandmapping=ligandmapping, # type: ignore generation=0, repeat_id=int(uuid.uuid4()), name=f'{Anames} to {Bnames} repeat {i} generation 0') for i in range(n_repeats)] diff --git a/openfe/setup/alchemical_network_planner/relative_alchemical_network_planner.py b/openfe/setup/alchemical_network_planner/relative_alchemical_network_planner.py index cda2cb720..3dc060e91 100644 --- a/openfe/setup/alchemical_network_planner/relative_alchemical_network_planner.py +++ b/openfe/setup/alchemical_network_planner/relative_alchemical_network_planner.py @@ -218,7 +218,7 @@ def _build_transformation( return Transformation( stateA=stateA, stateB=stateB, - mapping={RFEComponentLabels.LIGAND: ligand_mapping_edge}, + mapping=ligand_mapping_edge, name=transformation_name, protocol=transformation_protocol, ) diff --git a/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py b/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py index 869672efe..1f0d3daaa 100644 --- a/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py +++ b/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py @@ -1,40 +1,36 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -import os -from io import StringIO import copy -import numpy as np -import gufe -from gufe.tests.test_tokenization import GufeTokenizableTestsMixin import json -import pytest -from unittest import mock -from openff.units.openmm import to_openmm, from_openmm -from openff.units import unit -from importlib import resources import xml.etree.ElementTree as ET +from importlib import resources +from unittest import mock +import gufe +import mdtraj as mdt +import numpy as np +import pytest +from openff.units import unit +from openff.units.openmm import ensure_quantity +from openff.units.openmm import to_openmm, from_openmm from openmm import ( app, XmlSerializer, MonteCarloBarostat, NonbondedForce, CustomNonbondedForce ) from openmm import unit as omm_unit +from openmmforcefields.generators import SMIRNOFFTemplateGenerator from openmmtools.multistate.multistatesampler import MultiStateSampler -import pathlib from rdkit import Chem from rdkit.Geometry import Point3D -import mdtraj as mdt import openfe from openfe import setup from openfe.protocols import openmm_rfe +from openfe.protocols.openmm_rfe._rfe_utils import topologyhelpers from openfe.protocols.openmm_rfe.equil_rfe_methods import ( - _validate_alchemical_components, _get_alchemical_charge_difference + _validate_alchemical_components, _get_alchemical_charge_difference ) -from openfe.protocols.openmm_rfe._rfe_utils import topologyhelpers from openfe.protocols.openmm_utils import system_creation -from openmmforcefields.generators import SMIRNOFFTemplateGenerator -from openff.units.openmm import ensure_quantity def test_compute_platform_warn(): @@ -137,12 +133,12 @@ def test_create_independent_repeat_ids(benzene_system, toluene_system, benzene_t dag1 = protocol.create( stateA=benzene_system, stateB=toluene_system, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=benzene_to_toluene_mapping, ) dag2 = protocol.create( stateA=benzene_system, stateB=toluene_system, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=benzene_to_toluene_mapping, ) repeat_ids = set() @@ -156,7 +152,7 @@ def test_create_independent_repeat_ids(benzene_system, toluene_system, benzene_t @pytest.mark.parametrize('mapping', [ - None, {'A': 'Foo', 'B': 'bar'}, + None, [], ['A', 'B'], ]) def test_validate_alchemical_components_wrong_mappings(mapping): with pytest.raises(ValueError, match="A single LigandAtomMapping"): @@ -170,7 +166,7 @@ def test_validate_alchemical_components_missing_alchem_comp( alchem_comps = {'stateA': [openfe.SolventComponent(), ], 'stateB': []} with pytest.raises(ValueError, match="Unmapped alchemical component"): _validate_alchemical_components( - alchem_comps, {'ligand': benzene_to_toluene_mapping}, + alchem_comps, benzene_to_toluene_mapping, ) @@ -193,7 +189,7 @@ def test_dry_run_default_vacuum(benzene_vacuum_system, toluene_vacuum_system, dag = protocol.create( stateA=benzene_vacuum_system, stateB=toluene_vacuum_system, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=benzene_to_toluene_mapping, ) dag_unit = list(dag.protocol_units)[0] @@ -237,7 +233,7 @@ def test_dry_run_gaff_vacuum(benzene_vacuum_system, toluene_vacuum_system, dag = protocol.create( stateA=benzene_vacuum_system, stateB=toluene_vacuum_system, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=benzene_to_toluene_mapping, ) unit = list(dag.protocol_units)[0] @@ -263,7 +259,7 @@ def test_dry_many_molecules_solvent( dag = protocol.create( stateA=benzene_many_solv_system, stateB=toluene_many_solv_system, - mapping={'spicyligand': benzene_to_toluene_mapping}, + mapping=benzene_to_toluene_mapping, ) unit = list(dag.protocol_units)[0] @@ -357,7 +353,7 @@ def test_dry_core_element_change(tmpdir): dag = protocol.create( stateA=openfe.ChemicalSystem({'ligand': benz, }), stateB=openfe.ChemicalSystem({'ligand': pyr, }), - mapping={'whatamapping': mapping}, + mapping=mapping, ) dag_unit = list(dag.protocol_units)[0] @@ -393,7 +389,7 @@ def test_dry_run_ligand(benzene_system, toluene_system, dag = protocol.create( stateA=benzene_system, stateB=toluene_system, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=benzene_to_toluene_mapping, ) dag_unit = list(dag.protocol_units)[0] @@ -421,7 +417,7 @@ def test_confgen_mocked_fail(benzene_system, toluene_system, protocol = openmm_rfe.RelativeHybridTopologyProtocol(settings=settings) dag = protocol.create(stateA=benzene_system, stateB=toluene_system, - mapping={'ligand': benzene_to_toluene_mapping}) + mapping=benzene_to_toluene_mapping) dag_unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): @@ -456,7 +452,7 @@ def tip4p_hybrid_factory( dag = protocol.create( stateA=benzene_system, stateB=toluene_system, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=benzene_to_toluene_mapping, ) dag_unit = list(dag.protocol_units)[0] @@ -643,7 +639,7 @@ def check_propchgs(smc, charge_array): dag = protocol.create( stateA=openfe.ChemicalSystem({'l': benzene_smc, }), stateB=openfe.ChemicalSystem({'l': toluene_smc, }), - mapping={'ligand': mapping}, + mapping=mapping, ) dag_unit = list(dag.protocol_units)[0] @@ -737,7 +733,7 @@ def test_virtual_sites_no_reassign(benzene_system, toluene_system, dag = protocol.create( stateA=benzene_system, stateB=toluene_system, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=benzene_to_toluene_mapping, ) dag_unit = list(dag.protocol_units)[0] @@ -763,7 +759,7 @@ def test_dry_run_complex(benzene_complex_system, toluene_complex_system, dag = protocol.create( stateA=benzene_complex_system, stateB=toluene_complex_system, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=benzene_to_toluene_mapping, ) dag_unit = list(dag.protocol_units)[0] @@ -806,7 +802,7 @@ def test_hightimestep(benzene_vacuum_system, dag = p.create( stateA=benzene_vacuum_system, stateB=toluene_vacuum_system, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=benzene_to_toluene_mapping, ) dag_unit = list(dag.protocol_units)[0] @@ -837,7 +833,7 @@ def test_n_replicas_not_n_windows(benzene_vacuum_system, dag = p.create( stateA=benzene_vacuum_system, stateB=toluene_vacuum_system, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=benzene_to_toluene_mapping, ) dag_unit = list(dag.protocol_units)[0] dag_unit.run(dry=True) @@ -856,7 +852,7 @@ def test_missing_ligand(benzene_system, benzene_to_toluene_mapping): _ = p.create( stateA=benzene_system, stateB=stateB, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=benzene_to_toluene_mapping, ) @@ -873,7 +869,7 @@ def test_vaccuum_PME_error(benzene_vacuum_system, benzene_modifications, _ = p.create( stateA=benzene_vacuum_system, stateB=stateB, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=benzene_to_toluene_mapping, ) @@ -896,7 +892,7 @@ def test_incompatible_solvent(benzene_system, benzene_modifications, _ = p.create( stateA=benzene_system, stateB=stateB, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=benzene_to_toluene_mapping, ) @@ -917,7 +913,7 @@ def test_mapping_mismatch_A(benzene_system, toluene_system, _ = p.create( stateA=benzene_system, stateB=toluene_system, - mapping={'ligand': mapping}, + mapping=mapping, ) @@ -937,7 +933,7 @@ def test_mapping_mismatch_B(benzene_system, toluene_system, _ = p.create( stateA=benzene_system, stateB=toluene_system, - mapping={'ligand': mapping}, + mapping=mapping, ) @@ -951,12 +947,12 @@ def test_complex_mismatch(benzene_system, toluene_complex_system, _ = p.create( stateA=benzene_system, stateB=toluene_complex_system, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=benzene_to_toluene_mapping, ) def test_too_many_specified_mappings(benzene_system, toluene_system, - benzene_to_toluene_mapping): + benzene_to_toluene_mapping): # mapping dict requires 'ligand' key p = openmm_rfe.RelativeHybridTopologyProtocol( settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), @@ -966,8 +962,8 @@ def test_too_many_specified_mappings(benzene_system, toluene_system, _ = p.create( stateA=benzene_system, stateB=toluene_system, - mapping={'solvent': benzene_to_toluene_mapping, - 'ligand': benzene_to_toluene_mapping, } + mapping=[benzene_to_toluene_mapping, + benzene_to_toluene_mapping], ) @@ -990,7 +986,7 @@ def test_protein_mismatch(benzene_complex_system, toluene_complex_system, _ = p.create( stateA=benzene_complex_system, stateB=alt_toluene_complex_system, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=benzene_to_toluene_mapping, ) @@ -1015,7 +1011,7 @@ def test_element_change_warning(atom_mapping_basic_test_files): with pytest.warns(UserWarning, match="Element change"): _ = p.create( stateA=sys1, stateB=sys2, - mapping={'ligand': mapping}, + mapping=mapping, ) @@ -1052,7 +1048,7 @@ def test_ligand_overlap_warning(benzene_vacuum_system, toluene_vacuum_system, with pytest.warns(UserWarning, match='0 : 4 deviates'): dag = protocol.create( stateA=sysA, stateB=toluene_vacuum_system, - mapping={'ligand': mapping}, + mapping=mapping, ) dag_unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): @@ -1067,7 +1063,7 @@ def solvent_protocol_dag(benzene_system, toluene_system, benzene_to_toluene_mapp ) return protocol.create( stateA=benzene_system, stateB=toluene_system, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=benzene_to_toluene_mapping, ) @@ -1340,7 +1336,7 @@ def tyk2_xml(tmp_path_factory): dag = protocol.create( stateA=openfe.ChemicalSystem({'ligand': lig23}), stateB=openfe.ChemicalSystem({'ligand': lig55}), - mapping={'ligand': mapping}, + mapping=mapping, ) pu = list(dag.protocol_units)[0] @@ -1883,7 +1879,7 @@ def test_dry_run_alchemwater_solvent(benzene_to_benzoic_mapping, tmpdir): dag = protocol.create( stateA=stateA_system, stateB=stateB_system, - mapping={'ligand': benzene_to_benzoic_mapping}, + mapping=benzene_to_benzoic_mapping, ) unit = list(dag.protocol_units)[0] @@ -1936,7 +1932,7 @@ def test_dry_run_complex_alchemwater_totcharge( dag = protocol.create( stateA=stateA_system, stateB=stateB_system, - mapping={'ligand': mapping}, + mapping=mapping, ) unit = list(dag.protocol_units)[0] diff --git a/openfe/tests/protocols/test_openmm_rfe_slow.py b/openfe/tests/protocols/test_openmm_rfe_slow.py index 4e8a8b522..26160b7c5 100644 --- a/openfe/tests/protocols/test_openmm_rfe_slow.py +++ b/openfe/tests/protocols/test_openmm_rfe_slow.py @@ -67,7 +67,7 @@ def test_openmm_run_engine(benzene_vacuum_system, platform, m = openfe.LigandAtomMapping(componentA=b, componentB=b_alt, componentA_to_componentB={i: i for i in range(12)}) dag = p.create(stateA=benzene_vacuum_system, stateB=benzene_vacuum_alt_system, - mapping={'ligand': m}) + mapping=[m]) cwd = pathlib.Path(str(tmpdir)) r = execute_DAG(dag, shared_basedir=cwd, scratch_basedir=cwd, @@ -128,7 +128,7 @@ def test_run_eg5_sim(eg5_protein, eg5_ligands, eg5_cofactor, tmpdir): sys2 = openfe.ChemicalSystem(components={**base_sys, 'ligand': l2}) dag = p.create(stateA=sys1, stateB=sys2, - mapping={'ligand': m}) + mapping=[m]) cwd = pathlib.Path(str(tmpdir)) r = execute_DAG(dag, shared_basedir=cwd, scratch_basedir=cwd, diff --git a/openfe/tests/protocols/test_rfe_tokenization.py b/openfe/tests/protocols/test_rfe_tokenization.py index 951a22c6e..7e017eca4 100644 --- a/openfe/tests/protocols/test_rfe_tokenization.py +++ b/openfe/tests/protocols/test_rfe_tokenization.py @@ -21,7 +21,7 @@ def protocol(): def protocol_unit(protocol, benzene_system, toluene_system, benzene_to_toluene_mapping): pus = protocol.create( stateA=benzene_system, stateB=toluene_system, - mapping={'ligand': benzene_to_toluene_mapping}, + mapping=[benzene_to_toluene_mapping], ) return list(pus.protocol_units)[0]