-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
363 additions
and
15 deletions.
There are no files selected for viewing
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
""" | ||
Dev script to generate some result jsons that are used for testing | ||
Generates | ||
- ASFEProtocol_json_results.gz | ||
""" | ||
import gzip | ||
import json | ||
import logging | ||
import pathlib | ||
import tempfile | ||
from openff.toolkit import ( | ||
Molecule, RDKitToolkitWrapper, AmberToolsToolkitWrapper | ||
) | ||
from openff.toolkit.utils.toolkit_registry import ( | ||
toolkit_registry_manager, ToolkitRegistry | ||
) | ||
from openff.units import unit | ||
from kartograf.atom_aligner import align_mol_shape | ||
from kartograf import KartografAtomMapper | ||
import gufe | ||
from gufe.tokenization import JSON_HANDLER | ||
import openfe | ||
from pontibus.protocols.solvation import ASFEProtocol | ||
from pontibus.components import ExtendedSolventComponent | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
LIGA = "[H]C([H])([H])C([H])([H])C(=O)C([H])([H])C([H])([H])[H]" | ||
|
||
amber_rdkit = ToolkitRegistry( | ||
[RDKitToolkitWrapper(), AmberToolsToolkitWrapper()] | ||
) | ||
|
||
|
||
def get_molecule(smi, name): | ||
with toolkit_registry_manager(amber_rdkit): | ||
m = Molecule.from_smiles(smi) | ||
m.generate_conformers() | ||
m.assign_partial_charges(partial_charge_method="am1bcc") | ||
return openfe.SmallMoleculeComponent.from_openff(m, name=name) | ||
|
||
|
||
def execute_and_serialize(dag, protocol, simname): | ||
logger.info(f"running {simname}") | ||
with tempfile.TemporaryDirectory() as tmpdir: | ||
workdir = pathlib.Path(tmpdir) | ||
dagres = gufe.protocols.execute_DAG( | ||
dag, | ||
shared_basedir=workdir, | ||
scratch_basedir=workdir, | ||
keep_shared=False, | ||
n_retries=3 | ||
) | ||
protres = protocol.gather([dagres]) | ||
|
||
outdict = { | ||
"estimate": protres.get_estimate(), | ||
"uncertainty": protres.get_uncertainty(), | ||
"protocol_result": protres.to_dict(), | ||
"unit_results": { | ||
unit.key: unit.to_keyed_dict() | ||
for unit in dagres.protocol_unit_results | ||
} | ||
} | ||
|
||
with gzip.open(f"{simname}_json_results.gz", 'wt') as zipfile: | ||
json.dump(outdict, zipfile, cls=JSON_HANDLER.encoder) | ||
|
||
|
||
def generate_ahfe_settings(): | ||
settings = ASFEProtocol.default_settings() | ||
settings.solvent_equil_simulation_settings.equilibration_length_nvt = 10 * unit.picosecond | ||
settings.solvent_equil_simulation_settings.equilibration_length = 10 * unit.picosecond | ||
settings.solvent_equil_simulation_settings.production_length = 10 * unit.picosecond | ||
settings.solvent_simulation_settings.equilibration_length = 10 * unit.picosecond | ||
settings.solvent_simulation_settings.production_length = 500 * unit.picosecond | ||
settings.vacuum_equil_simulation_settings.equilibration_length = 10 * unit.picosecond | ||
settings.vacuum_equil_simulation_settings.production_length = 10 * unit.picosecond | ||
settings.vacuum_simulation_settings.equilibration_length = 10 * unit.picosecond | ||
settings.vacuum_simulation_settings.production_length = 500 * unit.picosecond | ||
settings.protocol_repeats = 3 | ||
settings.vacuum_engine_settings.compute_platform = 'CPU' | ||
settings.solvent_engine_settings.compute_platform = 'CUDA' | ||
|
||
return settings | ||
|
||
|
||
def generate_asfe_json(smc): | ||
protocol = ASFEProtocol(settings=generate_ahfe_settings()) | ||
sysA = openfe.ChemicalSystem( | ||
{"ligand": smc, "solvent": ExtendedSolventComponent()} | ||
) | ||
sysB = openfe.ChemicalSystem( | ||
{"solvent": ExtendedSolventComponent()} | ||
) | ||
|
||
dag = protocol.create(stateA=sysA, stateB=sysB, mapping=None) | ||
|
||
execute_and_serialize(dag, protocol, "ASFEProtocol") | ||
|
||
|
||
if __name__ == "__main__": | ||
molA = get_molecule(LIGA, "ligandA") | ||
generate_asfe_json(molA) |
Binary file added
BIN
+46.2 KB
src/pontibus/tests/data/solvation_protocol/ASFEProtocol_json_results.gz
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
import gzip | ||
import itertools | ||
import pytest | ||
import json | ||
from importlib import resources | ||
|
||
import numpy as np | ||
from openff.units import unit as offunit | ||
import gufe | ||
import openfe | ||
from pontibus.protocols.solvation import ASFEProtocolResult | ||
|
||
|
||
@pytest.fixture | ||
def afe_solv_transformation_json() -> str: | ||
""" | ||
ASFE results object as created by quickrun. | ||
generated with devtools/gent-serialized-results.py | ||
""" | ||
d = resources.files("pontibus.tests.data.solvation_protocol") | ||
fname = "ASFEProtocol_json_results.gz" | ||
|
||
with gzip.open((d / fname).as_posix(), 'r') as f: | ||
return f.read().decode() | ||
|
||
|
||
class TestProtocolResult: | ||
@pytest.fixture() | ||
def protocolresult(self, afe_solv_transformation_json): | ||
d = json.loads(afe_solv_transformation_json, | ||
cls=gufe.tokenization.JSON_HANDLER.decoder) | ||
|
||
pr = openfe.ProtocolResult.from_dict(d['protocol_result']) | ||
|
||
return pr | ||
|
||
def test_reload_protocol_result(self, afe_solv_transformation_json): | ||
d = json.loads(afe_solv_transformation_json, | ||
cls=gufe.tokenization.JSON_HANDLER.decoder) | ||
|
||
pr = ASFEProtocolResult.from_dict(d['protocol_result']) | ||
|
||
assert pr | ||
|
||
def test_get_estimate(self, protocolresult): | ||
est = protocolresult.get_estimate() | ||
|
||
assert est | ||
assert est.m == pytest.approx(-2.47, abs=0.5) | ||
assert isinstance(est, offunit.Quantity) | ||
assert est.is_compatible_with(offunit.kilojoule_per_mole) | ||
|
||
def test_get_uncertainty(self, protocolresult): | ||
est = protocolresult.get_uncertainty() | ||
|
||
assert est | ||
assert est.m == pytest.approx(0.2, abs=0.2) | ||
assert isinstance(est, offunit.Quantity) | ||
assert est.is_compatible_with(offunit.kilojoule_per_mole) | ||
|
||
def test_get_individual(self, protocolresult): | ||
inds = protocolresult.get_individual_estimates() | ||
|
||
assert isinstance(inds, dict) | ||
assert isinstance(inds['solvent'], list) | ||
assert isinstance(inds['vacuum'], list) | ||
assert len(inds['solvent']) == len(inds['vacuum']) == 3 | ||
for e, u in itertools.chain(inds['solvent'], inds['vacuum']): | ||
assert e.is_compatible_with(offunit.kilojoule_per_mole) | ||
assert u.is_compatible_with(offunit.kilojoule_per_mole) | ||
|
||
@pytest.mark.parametrize('key', ['solvent', 'vacuum']) | ||
def test_get_forwards_etc(self, key, protocolresult): | ||
far = protocolresult.get_forward_and_reverse_energy_analysis() | ||
|
||
assert isinstance(far, dict) | ||
assert isinstance(far[key], list) | ||
far1 = far[key][0] | ||
assert isinstance(far1, dict) | ||
|
||
for k in ['fractions', 'forward_DGs', 'forward_dDGs', | ||
'reverse_DGs', 'reverse_dDGs']: | ||
assert k in far1 | ||
|
||
if k == 'fractions': | ||
assert isinstance(far1[k], np.ndarray) | ||
|
||
@pytest.mark.parametrize('key', ['solvent', 'vacuum']) | ||
def test_get_frwd_reverse_none_return(self, key, protocolresult): | ||
# fetch the first result of type key | ||
data = [i for i in protocolresult.data[key].values()][0][0] | ||
# set the output to None | ||
data.outputs['forward_and_reverse_energies'] = None | ||
|
||
# now fetch the analysis results and expect a warning | ||
wmsg = ("were found in the forward and reverse dictionaries " | ||
f"of the repeats of the {key}") | ||
with pytest.warns(UserWarning, match=wmsg): | ||
protocolresult.get_forward_and_reverse_energy_analysis() | ||
|
||
@pytest.mark.parametrize('key', ['solvent', 'vacuum']) | ||
def test_get_overlap_matrices(self, key, protocolresult): | ||
ovp = protocolresult.get_overlap_matrices() | ||
|
||
assert isinstance(ovp, dict) | ||
assert isinstance(ovp[key], list) | ||
assert len(ovp[key]) == 3 | ||
|
||
ovp1 = ovp[key][0] | ||
assert isinstance(ovp1['matrix'], np.ndarray) | ||
assert ovp1['matrix'].shape == (14, 14) | ||
|
||
@pytest.mark.parametrize('key', ['solvent', 'vacuum']) | ||
def test_get_replica_transition_statistics(self, key, protocolresult): | ||
rpx = protocolresult.get_replica_transition_statistics() | ||
|
||
assert isinstance(rpx, dict) | ||
assert isinstance(rpx[key], list) | ||
assert len(rpx[key]) == 3 | ||
rpx1 = rpx[key][0] | ||
assert 'eigenvalues' in rpx1 | ||
assert 'matrix' in rpx1 | ||
assert rpx1['eigenvalues'].shape == (14,) | ||
assert rpx1['matrix'].shape == (14, 14) | ||
|
||
@pytest.mark.parametrize('key', ['solvent', 'vacuum']) | ||
def test_equilibration_iterations(self, key, protocolresult): | ||
eq = protocolresult.equilibration_iterations() | ||
|
||
assert isinstance(eq, dict) | ||
assert isinstance(eq[key], list) | ||
assert len(eq[key]) == 3 | ||
assert all(isinstance(v, float) for v in eq[key]) | ||
|
||
@pytest.mark.parametrize('key', ['solvent', 'vacuum']) | ||
def test_production_iterations(self, key, protocolresult): | ||
prod = protocolresult.production_iterations() | ||
|
||
assert isinstance(prod, dict) | ||
assert isinstance(prod[key], list) | ||
assert len(prod[key]) == 3 | ||
assert all(isinstance(v, float) for v in prod[key]) | ||
|
||
def test_filenotfound_replica_states(self, protocolresult): | ||
errmsg = "File could not be found" | ||
|
||
with pytest.raises(ValueError, match=errmsg): | ||
protocolresult.get_replica_states() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.