Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

consequences of gufe #260 change #666

Merged
merged 10 commits into from
Feb 8, 2024
2 changes: 1 addition & 1 deletion openfe/protocols/openmm_afe/equil_solvation_afe_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ def _create(
self,
stateA: ChemicalSystem,
stateB: ChemicalSystem,
mapping: Optional[dict[str, gufe.ComponentMapping]] = None,
mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]] = None,
extends: Optional[gufe.ProtocolDAGResult] = None,
) -> list[gufe.ProtocolUnit]:
# TODO: extensions
Expand Down
26 changes: 13 additions & 13 deletions openfe/protocols/openmm_rfe/equil_rfe_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def _get_alchemical_charge_difference(

def _validate_alchemical_components(
alchemical_components: dict[str, list[Component]],
mapping: Optional[dict[str, ComponentMapping]],
mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]],
):
"""
Checks that the alchemical components are suitable for the RFE protocol.
Expand All @@ -188,8 +188,8 @@ def _validate_alchemical_components(
alchemical_components : dict[str, list[Component]]
Dictionary contatining the alchemical components for
states A and B.
mapping : dict[str, ComponentMapping]
Dictionary of mappings between transforming components.
mapping : Optional[Union[ComponentMapping, list[ComponentMapping]]]
all mappings between transforming components.

Raises
------
Expand All @@ -201,16 +201,17 @@ def _validate_alchemical_components(
UserWarning
* Mappings which involve element changes in core atoms
"""
if isinstance(mapping, ComponentMapping):
mapping = [mapping]
# Check mapping
# For now we only allow for a single mapping, this will likely change
if mapping is None or len(mapping.values()) > 1:
if mapping is None or len(mapping) != 1:
errmsg = "A single LigandAtomMapping is expected for this Protocol"
raise ValueError(errmsg)

# Check that all alchemical components are mapped & small molecules
mapped = {}
mapped['stateA'] = [m.componentA for m in mapping.values()]
mapped['stateB'] = [m.componentB for m in mapping.values()]
mapped = {'stateA': [m.componentA for m in mapping],
'stateB': [m.componentB for m in mapping]}

for idx in ['stateA', 'stateB']:
if len(alchemical_components[idx]) != len(mapped[idx]):
Expand All @@ -226,7 +227,7 @@ def _validate_alchemical_components(
raise ValueError(errmsg)

# Validate element changes in mappings
for m in mapping.values():
for m in mapping:
molA = m.componentA.to_rdkit()
molB = m.componentB.to_rdkit()
for i, j in m.componentA_to_componentB.items():
Expand Down Expand Up @@ -470,7 +471,7 @@ def _create(
self,
stateA: ChemicalSystem,
stateB: ChemicalSystem,
mapping: Optional[dict[str, gufe.ComponentMapping]] = None,
mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]],
extends: Optional[gufe.ProtocolDAGResult] = None,
) -> list[gufe.ProtocolUnit]:
# TODO: Extensions?
Expand All @@ -482,9 +483,7 @@ def _create(
stateA, stateB
)
_validate_alchemical_components(alchem_comps, mapping)

# For now we've made it fail already if it was None,
ligandmapping = list(mapping.values())[0] # type: ignore
ligandmapping = mapping[0] if isinstance(mapping, list) else mapping # type: ignore

# Validate solvent component
nonbond = self.settings.forcefield_settings.nonbonded_method
Expand All @@ -500,7 +499,8 @@ def _create(
n_repeats = self.settings.protocol_repeats
units = [RelativeHybridTopologyProtocolUnit(
protocol=self,
stateA=stateA, stateB=stateB, ligandmapping=ligandmapping,
stateA=stateA, stateB=stateB,
ligandmapping=ligandmapping, # type: ignore
generation=0, repeat_id=int(uuid.uuid4()),
name=f'{Anames} to {Bnames} repeat {i} generation 0')
for i in range(n_repeats)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def _build_transformation(
return Transformation(
stateA=stateA,
stateB=stateB,
mapping={RFEComponentLabels.LIGAND: ligand_mapping_edge},
mapping=ligand_mapping_edge,
name=transformation_name,
protocol=transformation_protocol,
)
Expand Down
Loading
Loading