diff --git a/news/support_gather_parallel.rst b/news/support_gather_parallel.rst new file mode 100644 index 000000000..5b3c00b8f --- /dev/null +++ b/news/support_gather_parallel.rst @@ -0,0 +1,23 @@ +**Added:** + +* ``openfe gather`` now supports replicates that have been submitted in parallel across separate directories. + +**Changed:** + +* + +**Deprecated:** + +* + +**Removed:** + +* + +**Fixed:** + +* + +**Security:** + +* diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index 2b80370e0..7bf92050e 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -256,6 +256,15 @@ def __init__(self, **data): if any(len(pur_list) > 2 for pur_list in self.data.values()): raise NotImplementedError("Can't stitch together results yet") + @staticmethod + def compute_mean_estimate(dGs:list[unit.Quantity]): + u = dGs[0].u + # convert all values to units of the first value, then take average of magnitude + # this would avoid a screwy case where each value was in different units + vals = [dG.to(u).m for dG in dGs] + + return np.average(vals) * u + def get_estimate(self) -> unit.Quantity: """Average free energy difference of this transformation @@ -267,24 +276,25 @@ def get_estimate(self) -> unit.Quantity: """ # TODO: Check this holds up completely for SAMS. 
dGs = [pus[0].outputs['unit_estimate'] for pus in self.data.values()] + return self.compute_mean_estimate(dGs) + + @staticmethod + def compute_uncertainty(dGs:list[unit.Quantity]): u = dGs[0].u # convert all values to units of the first value, then take average of magnitude # this would avoid a screwy case where each value was in different units vals = [dG.to(u).m for dG in dGs] - return np.average(vals) * u + return np.std(vals) * u def get_uncertainty(self) -> unit.Quantity: """The uncertainty/error in the dG value: The std of the estimates of each independent repeat """ + dGs = [pus[0].outputs['unit_estimate'] for pus in self.data.values()] - u = dGs[0].u - # convert all values to units of the first value, then take average of magnitude - # this would avoid a screwy case where each value was in different units - vals = [dG.to(u).m for dG in dGs] + return self.compute_uncertainty(dGs) - return np.std(vals) * u def get_individual_estimates(self) -> list[tuple[unit.Quantity, unit.Quantity]]: """Return a list of tuples containing the individual free energy diff --git a/openfecli/commands/gather.py b/openfecli/commands/gather.py index 52a16a2b2..a3476b400 100644 --- a/openfecli/commands/gather.py +++ b/openfecli/commands/gather.py @@ -7,6 +7,8 @@ from typing import Callable, Literal import warnings +from openfe.protocols.openmm_rfe.equil_rfe_methods import RelativeHybridTopologyProtocolResult as rfe_result +from openfe.protocols import openmm_rfe from openfecli import OFECommandPlugin from openfecli.clicktypes import HyphenAwareChoice @@ -200,7 +202,6 @@ def _parse_raw_units(results: dict) -> list[tuple]: pu[0]['outputs']['unit_estimate_error']) for pu in list_of_pur] - def _get_ddgs(legs:dict, error_on_missing=True): import numpy as np DDGs = [] @@ -215,16 +216,20 @@ def _get_ddgs(legs:dict, error_on_missing=True): do_rhfe = (len(set_vals & {'vacuum', 'solvent'}) == 2) if do_rbfe: - DG1_mag, DG1_unc = vals['complex'] - DG2_mag, DG2_unc = vals['solvent'] + DG1_mag = 
rfe_result.compute_mean_estimate(vals['complex']) + DG1_unc = rfe_result.compute_uncertainty(vals['complex']) + DG2_mag = rfe_result.compute_mean_estimate(vals['solvent']) + DG2_unc = rfe_result.compute_uncertainty(vals['solvent']) if not ((DG1_mag is None) or (DG2_mag is None)): # DDG(2,1)bind = DG(1->2)complex - DG(1->2)solvent DDGbind = (DG1_mag - DG2_mag).m bind_unc = np.sqrt(np.sum(np.square([DG1_unc.m, DG2_unc.m]))) if do_rhfe: - DG1_mag, DG1_unc = vals['solvent'] - DG2_mag, DG2_unc = vals['vacuum'] + DG1_mag = rfe_result.compute_mean_estimate(vals['solvent']) + DG1_unc = rfe_result.compute_uncertainty(vals['solvent']) + DG2_mag = rfe_result.compute_mean_estimate(vals['vacuum']) + DG2_unc = rfe_result.compute_uncertainty(vals['vacuum']) if not ((DG1_mag is None) or (DG2_mag is None)): DDGhyd = (DG1_mag - DG2_mag).m hyd_unc = np.sqrt(np.sum(np.square([DG1_unc.m, DG2_unc.m]))) @@ -258,14 +263,15 @@ def _write_raw(legs:dict, writer:Callable, allow_partial=True): writer.writerow(["leg", "ligand_i", "ligand_j", "DG(i->j) (kcal/mol)", "MBAR uncertainty (kcal/mol)"]) - for ligpair, vals in sorted(legs.items()): - for simtype, repeats in sorted(vals.items()): - for m, u in repeats: - if m is None: - m, u = 'NaN', 'NaN' - else: - m, u = format_estimate_uncertainty(m.m, u.m) - writer.writerow([simtype, *ligpair, m, u]) + for ligpair, results in sorted(legs.items()): + for simtype, repeats in sorted(results.items()): + for repeat in repeats: + for m, u in repeat: + if m is None: + m, u = 'NaN', 'NaN' + else: + m, u = format_estimate_uncertainty(m.m, u.m) + writer.writerow([simtype, *ligpair, m, u]) def _write_dg_raw(legs:dict, writer:Callable, allow_partial): # pragma: no-cover @@ -400,7 +406,7 @@ def gather(rootdir:os.PathLike|str, result_fns = filter(is_results_json, json_fns) # 3) pair legs of simulations together into dict of dicts - legs = defaultdict(dict) + legs = defaultdict(lambda: defaultdict(list)) for result_fn in result_fns: result = load_results(result_fn) 
@@ -420,9 +426,11 @@ def gather(rootdir:os.PathLike|str, simtype = legacy_get_type(result_fn) if report.lower() == 'raw': - legs[names][simtype] = _parse_raw_units(result) + legs[names][simtype].append(_parse_raw_units(result)) else: - legs[names][simtype] = result['estimate'], result['uncertainty'] + dGs = [v[0]['outputs']['unit_estimate'] for v in result['protocol_result']['data'].values()] + ## for jobs run in parallel, we need to compute these values + legs[names][simtype].extend(dGs) writer = csv.writer( output, diff --git a/openfecli/tests/commands/test_gather.py b/openfecli/tests/commands/test_gather.py index 6508763cf..8949e48ea 100644 --- a/openfecli/tests/commands/test_gather.py +++ b/openfecli/tests/commands/test_gather.py @@ -26,27 +26,6 @@ def test_format_estimate_uncertainty(est, unc, unc_prec, est_str, unc_str): def test_get_column(val, col): assert _get_column(val) == col - -@pytest.fixture -def results_dir_serial(tmpdir): - """Example output data, with replicates run in serial (3 replicates per results JSON).""" - with tmpdir.as_cwd(): - with resources.files('openfecli.tests.data') as d: - t = tarfile.open(d / 'rbfe_results.tar.gz', mode='r') - t.extractall('.') - - yield - -@pytest.fixture -def results_dir_parallel(tmpdir): - """Identical data to results_dir_serial(), with replicates run in parallel (1 replicate per results JSON).""" - with tmpdir.as_cwd(): - with resources.files('openfecli.tests.data') as d: - t = tarfile.open(d / 'results_parallel.tar.gz', mode='r') - t.extractall('.') - - yield - _EXPECTED_DG = b""" ligand DG(MLE) (kcal/mol) uncertainty (kcal/mol) lig_ejm_31 -0.09 0.05 @@ -155,9 +134,29 @@ def results_dir_parallel(tmpdir): solvent lig_ejm_46 lig_jmc_28 23.4 0.8 """ +@pytest.fixture() +def results_dir_serial(tmpdir): + """Example output data, with replicates run in serial (3 replicates per results JSON).""" + with tmpdir.as_cwd(): + with resources.files('openfecli.tests.data') as d: + t = tarfile.open(d / 'rbfe_results.tar.gz', 
mode='r') + t.extractall('.') + + return os.path.abspath(t.getnames()[0]) + +@pytest.fixture() +def results_dir_parallel(tmpdir): + """Example output data, with replicates run in parallel (1 replicate per results JSON).""" + with tmpdir.as_cwd(): + with resources.files('openfecli.tests.data') as d: + t = tarfile.open(d / 'rbfe_results_parallel.tar.gz', mode='r') + t.extractall('.') + + return os.path.abspath(t.getnames()[0]) +@pytest.mark.parametrize('data_fixture', ['results_dir_serial', 'results_dir_parallel']) @pytest.mark.parametrize('report', ["", "dg", "ddg", "raw"]) -def test_gather(results_dir_serial, report): +def test_gather(request, data_fixture, report): expected = { "": _EXPECTED_DG, "dg": _EXPECTED_DG, @@ -171,7 +170,8 @@ def test_gather(results_dir_serial, report): else: args = [] - result = runner.invoke(gather, ['results'] + args + ['-o', '-']) + results_dir = request.getfixturevalue(data_fixture) + result = runner.invoke(gather, [results_dir] + args + ['-o', '-']) assert result.exit_code == 0 diff --git a/openfecli/tests/data/results_parallel.tar.gz b/openfecli/tests/data/rbfe_results_parallel.tar.gz similarity index 100% rename from openfecli/tests/data/results_parallel.tar.gz rename to openfecli/tests/data/rbfe_results_parallel.tar.gz