Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gather support for parallel dir structure #1044

Merged
merged 20 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions news/support_gather_parallel.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
**Added:**

* ``openfe gather`` now supports replicates that have been submitted in parallel across separate directories.
atravitz marked this conversation as resolved.
Show resolved Hide resolved

**Changed:**

* <news item>

**Deprecated:**

* <news item>

**Removed:**

* <news item>

**Fixed:**

* <news item>

**Security:**

* <news item>
22 changes: 16 additions & 6 deletions openfe/protocols/openmm_rfe/equil_rfe_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,15 @@ def __init__(self, **data):
if any(len(pur_list) > 2 for pur_list in self.data.values()):
raise NotImplementedError("Can't stitch together results yet")

@staticmethod
def compute_mean_estimate(dGs:list[dict]):
atravitz marked this conversation as resolved.
Show resolved Hide resolved
u = dGs[0].u
# convert all values to units of the first value, then take average of magnitude
# this would avoid a screwy case where each value was in different units
atravitz marked this conversation as resolved.
Show resolved Hide resolved
vals = [dG.to(u).m for dG in dGs]

return np.average(vals) * u
atravitz marked this conversation as resolved.
Show resolved Hide resolved

def get_estimate(self) -> unit.Quantity:
"""Average free energy difference of this transformation

Expand All @@ -267,24 +276,25 @@ def get_estimate(self) -> unit.Quantity:
"""
# TODO: Check this holds up completely for SAMS.
dGs = [pus[0].outputs['unit_estimate'] for pus in self.data.values()]
return self.compute_mean_estimate(dGs)

@staticmethod
def compute_uncertainty(dGs:list):
atravitz marked this conversation as resolved.
Show resolved Hide resolved
atravitz marked this conversation as resolved.
Show resolved Hide resolved
u = dGs[0].u
# convert all values to units of the first value, then take average of magnitude
# this would avoid a screwy case where each value was in different units
vals = [dG.to(u).m for dG in dGs]

return np.average(vals) * u
return np.std(vals) * u

def get_uncertainty(self) -> unit.Quantity:
"""The uncertainty/error in the dG value: The std of the estimates of
each independent repeat
"""

dGs = [pus[0].outputs['unit_estimate'] for pus in self.data.values()]
u = dGs[0].u
# convert all values to units of the first value, then take average of magnitude
# this would avoid a screwy case where each value was in different units
vals = [dG.to(u).m for dG in dGs]
return self.compute_uncertainty(dGs)

return np.std(vals) * u

def get_individual_estimates(self) -> list[tuple[unit.Quantity, unit.Quantity]]:
"""Return a list of tuples containing the individual free energy
Expand Down
40 changes: 24 additions & 16 deletions openfecli/commands/gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from typing import Callable, Literal
import warnings

from openfe.protocols.openmm_rfe.equil_rfe_methods import RelativeHybridTopologyProtocolResult as rfe_result
from openfe.protocols import openmm_rfe
from openfecli import OFECommandPlugin
from openfecli.clicktypes import HyphenAwareChoice

Expand Down Expand Up @@ -200,7 +202,6 @@
pu[0]['outputs']['unit_estimate_error'])
for pu in list_of_pur]


def _get_ddgs(legs:dict, error_on_missing=True):
import numpy as np
DDGs = []
Expand All @@ -215,16 +216,20 @@
do_rhfe = (len(set_vals & {'vacuum', 'solvent'}) == 2)

if do_rbfe:
DG1_mag, DG1_unc = vals['complex']
DG2_mag, DG2_unc = vals['solvent']
DG1_mag = rfe_result.compute_mean_estimate(vals['complex'])
DG1_unc = rfe_result.compute_uncertainty(vals['complex'])
DG2_mag = rfe_result.compute_mean_estimate(vals['solvent'])
DG2_unc = rfe_result.compute_uncertainty(vals['solvent'])
if not ((DG1_mag is None) or (DG2_mag is None)):
# DDG(2,1)bind = DG(1->2)complex - DG(1->2)solvent
DDGbind = (DG1_mag - DG2_mag).m
bind_unc = np.sqrt(np.sum(np.square([DG1_unc.m, DG2_unc.m])))

if do_rhfe:
DG1_mag, DG1_unc = vals['solvent']
DG2_mag, DG2_unc = vals['vacuum']
DG1_mag = rfe_result.compute_mean_estimate(vals['solvent'])
DG1_unc = rfe_result.compute_uncertainty(vals['solvent'])
DG2_mag = rfe_result.compute_mean_estimate(vals['vacuum'])
DG2_unc = rfe_result.compute_uncertainty(vals['vacuum'])

Check warning on line 232 in openfecli/commands/gather.py

View check run for this annotation

Codecov / codecov/patch

openfecli/commands/gather.py#L229-L232

Added lines #L229 - L232 were not covered by tests
if not ((DG1_mag is None) or (DG2_mag is None)):
DDGhyd = (DG1_mag - DG2_mag).m
hyd_unc = np.sqrt(np.sum(np.square([DG1_unc.m, DG2_unc.m])))
Expand Down Expand Up @@ -258,14 +263,15 @@
writer.writerow(["leg", "ligand_i", "ligand_j",
"DG(i->j) (kcal/mol)", "MBAR uncertainty (kcal/mol)"])

for ligpair, vals in sorted(legs.items()):
for simtype, repeats in sorted(vals.items()):
for m, u in repeats:
if m is None:
m, u = 'NaN', 'NaN'
else:
m, u = format_estimate_uncertainty(m.m, u.m)
writer.writerow([simtype, *ligpair, m, u])
for ligpair, results in sorted(legs.items()):
for simtype, repeats in sorted(results.items()):
for repeat in repeats:
for m, u in repeat:
if m is None:
m, u = 'NaN', 'NaN'

Check warning on line 271 in openfecli/commands/gather.py

View check run for this annotation

Codecov / codecov/patch

openfecli/commands/gather.py#L271

Added line #L271 was not covered by tests
else:
m, u = format_estimate_uncertainty(m.m, u.m)
writer.writerow([simtype, *ligpair, m, u])


def _write_dg_raw(legs:dict, writer:Callable, allow_partial): # pragma: no-cover
Expand Down Expand Up @@ -400,7 +406,7 @@
result_fns = filter(is_results_json, json_fns)

# 3) pair legs of simulations together into dict of dicts
legs = defaultdict(dict)
legs = defaultdict(lambda: defaultdict(list))

for result_fn in result_fns:
result = load_results(result_fn)
Expand All @@ -420,9 +426,11 @@
simtype = legacy_get_type(result_fn)

if report.lower() == 'raw':
legs[names][simtype] = _parse_raw_units(result)
legs[names][simtype].append(_parse_raw_units(result))
else:
legs[names][simtype] = result['estimate'], result['uncertainty']
dGs = [v[0]['outputs']['unit_estimate'] for v in result['protocol_result']['data'].values()]
## for jobs run in parallel, we need to compute these values
legs[names][simtype].extend(dGs)

writer = csv.writer(
output,
Expand Down
46 changes: 23 additions & 23 deletions openfecli/tests/commands/test_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,27 +26,6 @@ def test_format_estimate_uncertainty(est, unc, unc_prec, est_str, unc_str):
def test_get_column(val, col):
assert _get_column(val) == col


@pytest.fixture
def results_dir_serial(tmpdir):
"""Example output data, with replicates run in serial (3 replicates per results JSON)."""
with tmpdir.as_cwd():
with resources.files('openfecli.tests.data') as d:
t = tarfile.open(d / 'rbfe_results.tar.gz', mode='r')
t.extractall('.')

yield

@pytest.fixture
def results_dir_parallel(tmpdir):
"""Identical data to results_dir_serial(), with replicates run in parallel (1 replicate per results JSON)."""
with tmpdir.as_cwd():
with resources.files('openfecli.tests.data') as d:
t = tarfile.open(d / 'results_parallel.tar.gz', mode='r')
t.extractall('.')

yield

_EXPECTED_DG = b"""
ligand DG(MLE) (kcal/mol) uncertainty (kcal/mol)
lig_ejm_31 -0.09 0.05
Expand Down Expand Up @@ -155,9 +134,29 @@ def results_dir_parallel(tmpdir):
solvent lig_ejm_46 lig_jmc_28 23.4 0.8
"""

@pytest.fixture()
def results_dir_serial(tmpdir):
"""Example output data, with replicates run in serial (3 replicates per results JSON)."""
with tmpdir.as_cwd():
with resources.files('openfecli.tests.data') as d:
t = tarfile.open(d / 'rbfe_results.tar.gz', mode='r')
t.extractall('.')

return os.path.abspath(t.getnames()[0])

@pytest.fixture()
def results_dir_parallel(tmpdir):
"""Example output data, with replicates run in serial (3 replicates per results JSON)."""
with tmpdir.as_cwd():
with resources.files('openfecli.tests.data') as d:
t = tarfile.open(d / 'rbfe_results_parallel.tar.gz', mode='r')
t.extractall('.')

return os.path.abspath(t.getnames()[0])

@pytest.mark.parametrize('data_fixture', ['results_dir_serial', 'results_dir_parallel'])
@pytest.mark.parametrize('report', ["", "dg", "ddg", "raw"])
def test_gather(results_dir_serial, report):
def test_gather(request, data_fixture, report):
expected = {
"": _EXPECTED_DG,
"dg": _EXPECTED_DG,
Expand All @@ -171,7 +170,8 @@ def test_gather(results_dir_serial, report):
else:
args = []

result = runner.invoke(gather, ['results'] + args + ['-o', '-'])
results_dir = request.getfixturevalue(data_fixture)
result = runner.invoke(gather, [results_dir] + args + ['-o', '-'])

assert result.exit_code == 0

Expand Down
Loading