Skip to content

Commit

Permalink
Merge pull request #666 from OpenFreeEnergy/gufe_260_consequences
Browse files Browse the repository at this point in the history
consequences of gufe #260 change
  • Loading branch information
IAlibay authored Feb 8, 2024
2 parents 00ef4de + 1c9ee3b commit 0815897
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 66 deletions.
2 changes: 1 addition & 1 deletion openfe/protocols/openmm_afe/equil_solvation_afe_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ def _create(
self,
stateA: ChemicalSystem,
stateB: ChemicalSystem,
mapping: Optional[dict[str, gufe.ComponentMapping]] = None,
mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]] = None,
extends: Optional[gufe.ProtocolDAGResult] = None,
) -> list[gufe.ProtocolUnit]:
# TODO: extensions
Expand Down
26 changes: 13 additions & 13 deletions openfe/protocols/openmm_rfe/equil_rfe_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def _get_alchemical_charge_difference(

def _validate_alchemical_components(
alchemical_components: dict[str, list[Component]],
mapping: Optional[dict[str, ComponentMapping]],
mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]],
):
"""
Checks that the alchemical components are suitable for the RFE protocol.
Expand All @@ -188,8 +188,8 @@ def _validate_alchemical_components(
alchemical_components : dict[str, list[Component]]
Dictionary contatining the alchemical components for
states A and B.
mapping : dict[str, ComponentMapping]
Dictionary of mappings between transforming components.
mapping : Optional[Union[ComponentMapping, list[ComponentMapping]]]
all mappings between transforming components.
Raises
------
Expand All @@ -201,16 +201,17 @@ def _validate_alchemical_components(
UserWarning
* Mappings which involve element changes in core atoms
"""
if isinstance(mapping, ComponentMapping):
mapping = [mapping]
# Check mapping
# For now we only allow for a single mapping, this will likely change
if mapping is None or len(mapping.values()) > 1:
if mapping is None or len(mapping) != 1:
errmsg = "A single LigandAtomMapping is expected for this Protocol"
raise ValueError(errmsg)

# Check that all alchemical components are mapped & small molecules
mapped = {}
mapped['stateA'] = [m.componentA for m in mapping.values()]
mapped['stateB'] = [m.componentB for m in mapping.values()]
mapped = {'stateA': [m.componentA for m in mapping],
'stateB': [m.componentB for m in mapping]}

for idx in ['stateA', 'stateB']:
if len(alchemical_components[idx]) != len(mapped[idx]):
Expand All @@ -226,7 +227,7 @@ def _validate_alchemical_components(
raise ValueError(errmsg)

# Validate element changes in mappings
for m in mapping.values():
for m in mapping:
molA = m.componentA.to_rdkit()
molB = m.componentB.to_rdkit()
for i, j in m.componentA_to_componentB.items():
Expand Down Expand Up @@ -470,7 +471,7 @@ def _create(
self,
stateA: ChemicalSystem,
stateB: ChemicalSystem,
mapping: Optional[dict[str, gufe.ComponentMapping]] = None,
mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]],
extends: Optional[gufe.ProtocolDAGResult] = None,
) -> list[gufe.ProtocolUnit]:
# TODO: Extensions?
Expand All @@ -482,9 +483,7 @@ def _create(
stateA, stateB
)
_validate_alchemical_components(alchem_comps, mapping)

# For now we've made it fail already if it was None,
ligandmapping = list(mapping.values())[0] # type: ignore
ligandmapping = mapping[0] if isinstance(mapping, list) else mapping # type: ignore

# Validate solvent component
nonbond = self.settings.forcefield_settings.nonbonded_method
Expand All @@ -500,7 +499,8 @@ def _create(
n_repeats = self.settings.protocol_repeats
units = [RelativeHybridTopologyProtocolUnit(
protocol=self,
stateA=stateA, stateB=stateB, ligandmapping=ligandmapping,
stateA=stateA, stateB=stateB,
ligandmapping=ligandmapping, # type: ignore
generation=0, repeat_id=int(uuid.uuid4()),
name=f'{Anames} to {Bnames} repeat {i} generation 0')
for i in range(n_repeats)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def _build_transformation(
return Transformation(
stateA=stateA,
stateB=stateB,
mapping={RFEComponentLabels.LIGAND: ligand_mapping_edge},
mapping=ligand_mapping_edge,
name=transformation_name,
protocol=transformation_protocol,
)
Expand Down
Loading

0 comments on commit 0815897

Please sign in to comment.