Skip to content

Commit

Permalink
Add tests protocol results
Browse files Browse the repository at this point in the history
  • Loading branch information
hannahbaumann committed Dec 19, 2024
1 parent 2d42497 commit ac9f1a8
Show file tree
Hide file tree
Showing 3 changed files with 247 additions and 1 deletion.
Binary file not shown.
14 changes: 14 additions & 0 deletions openfe/tests/protocols/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,20 @@ def md_json() -> str:
return f.read().decode() # type: ignore


@pytest.fixture
def septop_json() -> str:
"""
string of a SepTop result (BACE ligand lig_03 to lig_0) generated by quickrun
generated with gen-serialized-results.py
"""
d = resources.files('openfe.tests.data.openmm_septop')
fname = "SepTopProtocol_json_results.gz"

with gzip.open((d / fname).as_posix(), 'r') as f: # type: ignore
return f.read().decode() # type: ignore


RFE_OUTPUT = pooch.create(
path=pooch.os_cache("openfe_analysis"),
base_url="doi:10.6084/m9.figshare.24101655",
Expand Down
234 changes: 233 additions & 1 deletion openfe/tests/protocols/test_openmm_septop.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,29 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
import pathlib

import pytest

import openfe.protocols.openmm_septop
from openfe import ChemicalSystem, SolventComponent
from openfe.protocols.openmm_septop import SepTopProtocol
from openfe.protocols.openmm_septop import (
SepTopProtocol,
SepTopComplexSetupUnit,
SepTopSolventSetupUnit,
SepTopProtocolResult,
)
from openfe.protocols.openmm_septop.equil_septop_method import _check_alchemical_charge_difference
from openfe.protocols.openmm_utils import system_validation
import numpy
import openmm
import openmm.app
import openmm.unit
from openff.units import unit as offunit
import gufe
from unittest import mock
import json
import itertools
import numpy as np

from openfe.protocols.openmm_septop.femto_utils import compute_energy, is_close
from openmmtools.alchemy import AlchemicalRegion, AbsoluteAlchemicalFactory
Expand Down Expand Up @@ -596,3 +611,220 @@ def test_two_ligands_charges(self, three_particle_system):
)
expected_energy = energy_fn(0, 2, 1.0, 0.8) + energy_fn(1, 2, 1.0, 0.2)
assert is_close(energy, expected_energy)


@pytest.fixture
def benzene_toluene_dag(benzene_complex_system, toluene_complex_system):
s = SepTopProtocol.default_settings()

protocol = SepTopProtocol(settings=s)

return protocol.create(stateA=benzene_complex_system, stateB=toluene_complex_system, mapping=None)


def test_unit_tagging(benzene_toluene_dag, tmpdir):
# test that executing the units includes correct gen and repeat info

dag_units = benzene_toluene_dag.protocol_units

with (
mock.patch('openfe.protocols.openmm_septop.equil_septop_method.SepTopComplexSetupUnit.run',
return_value={'system': 'system.xml.bz2', 'topology':
'topology.pdb'}),
# mock.patch(
# 'openfe.protocols.openmm_septop.equil_septop_method'
# '.SepTopComplexRunUnit.execute',
# return_value={'nc': 'file.nc', 'last_checkpoint': 'chck.nc'},
# ),
mock.patch(
'openfe.protocols.openmm_septop.equil_septop_method'
'.SepTopSolventSetupUnit.run',
return_value={'system': 'system.xml.bz2', 'topology':
'topology.pdb'}),
# mock.patch(
# 'openfe.protocols.openmm_septop.equil_septop_method'
# '.SepTopSolventRunUnit.execute',
# return_value={'nc': 'file.nc', 'last_checkpoint': 'chck.nc'}),
):
results = []
# For right now only testing the two SetupUnits
#ToDo: Add tests for RunUnits
for u in dag_units[:2]:
ret = u.execute(context=gufe.Context(tmpdir, tmpdir))
results.append(ret)

solv_repeats = set()
complex_repeats = set()
for ret in results:
assert isinstance(ret, gufe.ProtocolUnitResult)
assert ret.outputs['generation'] == 0
if ret.outputs['simtype'] == 'complex':
complex_repeats.add(ret.outputs['repeat_id'])
else:
solv_repeats.add(ret.outputs['repeat_id'])
# Repeat ids are random ints so just check their lengths
assert len(complex_repeats) == len(solv_repeats) == 1


# def test_gather(benzene_toluene_dag, tmpdir):
# # check that .gather behaves as expected
# with (
# mock.patch(
# 'openfe.protocols.openmm_septop.equil_septop_method'
# '.SepTopComplexSetupUnit.run',
# return_value={'system': pathlib.Path('system.xml.bz2'), 'topology':
# 'topology.pdb'}),
# # mock.patch(
# # 'openfe.protocols.openmm_septop.equil_septop_method'
# # '.SepTopComplexRunUnit.execute',
# # return_value={'nc': 'file.nc', 'last_checkpoint': 'chck.nc'},
# # ),
# mock.patch(
# 'openfe.protocols.openmm_septop.equil_septop_method'
# '.SepTopSolventSetupUnit.run',
# return_value={'system': pathlib.Path('system.xml.bz2'), 'topology':
# 'topology.pdb'}),
# # mock.patch(
# # 'openfe.protocols.openmm_septop.equil_septop_method'
# # '.SepTopSolventRunUnit.execute',
# # return_value={'nc': 'file.nc', 'last_checkpoint':
# # 'chck.nc'}),
# ):
# dagres = gufe.protocols.execute_DAG(benzene_toluene_dag,
# shared_basedir=tmpdir,
# scratch_basedir=tmpdir,
# keep_shared=True)
#
# protocol = SepTopProtocol(
# settings=SepTopProtocol.default_settings(),
# )
#
# res = protocol.gather([dagres])
#
# assert isinstance(res, openfe.protocols.openmm_septop.SepTopProtocolResult)


class TestProtocolResult:
@pytest.fixture()
def protocolresult(self, septop_json):
d = json.loads(septop_json,
cls=gufe.tokenization.JSON_HANDLER.decoder)

pr = openfe.ProtocolResult.from_dict(d['protocol_result'])

return pr

def test_reload_protocol_result(self, septop_json):
d = json.loads(septop_json,
cls=gufe.tokenization.JSON_HANDLER.decoder)

pr = SepTopProtocolResult.from_dict(d[
'protocol_result'])

assert pr

def test_get_estimate(self, protocolresult):
est = protocolresult.get_estimate()

assert est
assert est.m == pytest.approx(-3.03, abs=0.5)
assert isinstance(est, offunit.Quantity)
assert est.is_compatible_with(offunit.kilojoule_per_mole)

def test_get_uncertainty(self, protocolresult):
est = protocolresult.get_uncertainty()

assert est.m == pytest.approx(0.0, abs=0.2)
assert isinstance(est, offunit.Quantity)
assert est.is_compatible_with(offunit.kilojoule_per_mole)

def test_get_individual(self, protocolresult):
inds = protocolresult.get_individual_estimates()

assert isinstance(inds, dict)
assert isinstance(inds['solvent'], list)
assert isinstance(inds['complex'], list)
assert len(inds['solvent']) == len(inds['complex']) == 1
for e, u in itertools.chain(inds['solvent'], inds['complex']):
assert e.is_compatible_with(offunit.kilojoule_per_mole)
assert u.is_compatible_with(offunit.kilojoule_per_mole)

#ToDo: Add Results from longer test run that has this analysis

# @pytest.mark.parametrize('key', ['solvent', 'complex'])
# def test_get_forwards_etc(self, key, protocolresult):
# far = protocolresult.get_forward_and_reverse_energy_analysis()
#
# assert isinstance(far, dict)
# assert isinstance(far[key], list)
# far1 = far[key][0]
# assert isinstance(far1, dict)
#
# for k in ['fractions', 'forward_DGs', 'forward_dDGs',
# 'reverse_DGs', 'reverse_dDGs']:
# assert k in far1
#
# if k == 'fractions':
# assert isinstance(far1[k], np.ndarray)
#
# @pytest.mark.parametrize('key', ['solvent', 'complex'])
# def test_get_frwd_reverse_none_return(self, key, protocolresult):
# # fetch the first result of type key
# data = [i for i in protocolresult.data[key].values()][0][0]
# # set the output to None
# data.outputs['forward_and_reverse_energies'] = None
#
# # now fetch the analysis results and expect a warning
# wmsg = ("were found in the forward and reverse dictionaries "
# f"of the repeats of the {key}")
# with pytest.warns(UserWarning, match=wmsg):
# protocolresult.get_forward_and_reverse_energy_analysis()
#
@pytest.mark.parametrize('key', ['solvent', 'complex'])
def test_get_overlap_matrices(self, key, protocolresult):
ovp = protocolresult.get_overlap_matrices()

assert isinstance(ovp, dict)
assert isinstance(ovp[key], list)
assert len(ovp[key]) == 1

ovp1 = ovp[key][0]
assert isinstance(ovp1['matrix'], np.ndarray)
assert ovp1['matrix'].shape == (19, 19)

@pytest.mark.parametrize('key', ['solvent', 'complex'])
def test_get_replica_transition_statistics(self, key, protocolresult):
rpx = protocolresult.get_replica_transition_statistics()

assert isinstance(rpx, dict)
assert isinstance(rpx[key], list)
assert len(rpx[key]) == 1
rpx1 = rpx[key][0]
assert 'eigenvalues' in rpx1
assert 'matrix' in rpx1
assert rpx1['eigenvalues'].shape == (19,)
assert rpx1['matrix'].shape == (19, 19)

@pytest.mark.parametrize('key', ['solvent', 'complex'])
def test_equilibration_iterations(self, key, protocolresult):
eq = protocolresult.equilibration_iterations()

assert isinstance(eq, dict)
assert isinstance(eq[key], list)
assert len(eq[key]) == 1
assert all(isinstance(v, float) for v in eq[key])

@pytest.mark.parametrize('key', ['solvent', 'complex'])
def test_production_iterations(self, key, protocolresult):
prod = protocolresult.production_iterations()

assert isinstance(prod, dict)
assert isinstance(prod[key], list)
assert len(prod[key]) == 1
assert all(isinstance(v, float) for v in prod[key])

def test_filenotfound_replica_states(self, protocolresult):
errmsg = "File could not be found"

with pytest.raises(ValueError, match=errmsg):
protocolresult.get_replica_states()

0 comments on commit ac9f1a8

Please sign in to comment.