diff --git a/gufe/protocols/__init__.py b/gufe/protocols/__init__.py index fca9b580..159ea63e 100644 --- a/gufe/protocols/__init__.py +++ b/gufe/protocols/__init__.py @@ -1,5 +1,13 @@ """Defining processes for performing estimates of free energy differences""" +from .errors import ( + GufeProtocolError, + MissingUnitResultError, + ProtocolDAGResultError, + ProtocolUnitExecutionError, + ProtocolUnitFailureError, + ProtocolValidationError, +) from .protocol import Protocol, ProtocolResult from .protocoldag import ProtocolDAG, ProtocolDAGResult, execute_DAG from .protocolunit import Context, ProtocolUnit, ProtocolUnitFailure, ProtocolUnitResult diff --git a/gufe/protocols/errors.py b/gufe/protocols/errors.py new file mode 100644 index 00000000..fef33e77 --- /dev/null +++ b/gufe/protocols/errors.py @@ -0,0 +1,28 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/gufe + + +class GufeProtocolError(Exception): + """The base gufe error that other errors should subclass.""" + + +# Protocol Errors +class ProtocolValidationError(GufeProtocolError): + """Error when the protocol setup or settings can not be validated.""" + + +class ProtocolUnitExecutionError(GufeProtocolError): + """Error when executing a protocol unit.""" + + +# Protocol Results Errors +class ProtocolDAGResultError(GufeProtocolError): + """Base error when dealing with DAG results.""" + + +class MissingUnitResultError(ProtocolDAGResultError): + """Error when a ProtocolDAGResult has no ProtocolUnitResult(s) for a given ProtocolUnit.""" + + +class ProtocolUnitFailureError(ProtocolDAGResultError): + """Error when a ProtocolDAGResult contains a failed protocol unit.""" diff --git a/gufe/protocols/protocoldag.py b/gufe/protocols/protocoldag.py index 74e7afbd..9a23e880 100644 --- a/gufe/protocols/protocoldag.py +++ b/gufe/protocols/protocoldag.py @@ -14,6 +14,7 @@ import networkx as nx from ..tokenization import GufeKey, GufeTokenizable +from .errors import MissingUnitResultError, ProtocolUnitFailureError from .protocolunit import Context, ProtocolUnit, ProtocolUnitFailure, ProtocolUnitResult @@ -211,22 +212,24 @@ def unit_to_result(self, protocol_unit: ProtocolUnit) -> ProtocolUnitResult: Raises ------ - KeyError - if either there are no results, or only failures + MissingUnitResultError: + if there are no results for that protocol unit + ProtocolUnitFailureError: + if there are only failures for that protocol unit """ try: units = self._unit_result_mapping[protocol_unit] except KeyError: - raise KeyError("No such `protocol_unit` present") + raise MissingUnitResultError(f"No such `protocol_unit`:{protocol_unit} present") else: for u in units: if u.ok(): return u else: - raise KeyError("No success for `protocol_unit` found") + raise ProtocolUnitFailureError(f"No success for `protocol_unit`:{protocol_unit} found") def unit_to_all_results(self, protocol_unit: ProtocolUnit) -> list[ProtocolUnitResult]: - """Return all results (sucess and failure) for a given Unit. + """Return all results (success and failure) for a given Unit. Returns ------- @@ -235,19 +238,19 @@ def unit_to_all_results(self, protocol_unit: ProtocolUnit) -> list[ProtocolUnitR Raises ------ - KeyError + MissingUnitResultError if no results present for a given unit """ try: return self._unit_result_mapping[protocol_unit] except KeyError: - raise KeyError("No such `protocol_unit` present") + raise MissingUnitResultError(f"No such `protocol_unit`:{protocol_unit} present") def result_to_unit(self, protocol_unit_result: ProtocolUnitResult) -> ProtocolUnit: try: return self._result_unit_mapping[protocol_unit_result] except KeyError: - raise KeyError("No such `protocol_unit_result` present") + raise MissingUnitResultError(f"No such `protocol_unit_result`:{protocol_unit_result} present") def ok(self) -> bool: # ensure that for every protocol unit, there is an OK result object diff --git a/gufe/tests/test_protocol.py b/gufe/tests/test_protocol.py index 50fb8c0f..bd8dc9ea 100644 --- a/gufe/tests/test_protocol.py +++ b/gufe/tests/test_protocol.py @@ -17,12 +17,14 @@ from gufe.chemicalsystem import ChemicalSystem from gufe.mapping import ComponentMapping from gufe.protocols import ( + MissingUnitResultError, Protocol, ProtocolDAG, ProtocolDAGResult, ProtocolResult, ProtocolUnit, ProtocolUnitFailure, + ProtocolUnitFailureError, ProtocolUnitResult, ) from gufe.protocols.protocoldag import execute_DAG @@ -692,7 +694,7 @@ def test_missing_result(self, units, successes, failures): assert not dagresult.ok() - with pytest.raises(KeyError, match="No success for `protocol_unit` found") as e: + with pytest.raises(ProtocolUnitFailureError, match="No success for `protocol_unit`:NoDepUnit\(None\) found"): dagresult.unit_to_result(units[2]) def test_plenty_of_fails(self, units, successes, failures): @@ -721,11 +723,13 @@ def test_foreign_objects(self, units, successes): transformation_key=None, ) - with pytest.raises(KeyError, match="No such `protocol_unit` present"): + with pytest.raises(MissingUnitResultError, match="No such `protocol_unit`:NoDepUnit\(None\) present"): dagresult.unit_to_result(units[2]) - with pytest.raises(KeyError, match="No such `protocol_unit` present"): + with pytest.raises(MissingUnitResultError, match="No such `protocol_unit`:NoDepUnit\(None\) present"): dagresult.unit_to_all_results(units[2]) - with pytest.raises(KeyError, match="No such `protocol_unit_result` present"): + with pytest.raises( + MissingUnitResultError, match="No such `protocol_unit_result`:ProtocolUnitResult\(None\) present" + ): dagresult.result_to_unit(successes[2])