diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 00000000..e66d5736 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,2 @@ +# https://github.com/OpenFreeEnergy/gufe/pull/421 -- big auto format MMH +d27100a5b7b303df155e2b6d7874883d105e24bf diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 00000000..d34f4787 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,22 @@ + + + + + +Tips +* Comment "pre-commit.ci autofix" to have pre-commit.ci atomically format your PR. + Since this will create a commit, it is best to make this comment when you are finished with your work. + + +Checklist +* [ ] Added a ``news`` entry + +## Developers certificate of origin +- [ ] I certify that this contribution is covered by the MIT License [here](https://github.com/OpenFreeEnergy/openfe/blob/main/LICENSE) and the **Developer Certificate of Origin** at . diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 1914b712..6df2ffd6 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -27,18 +27,15 @@ jobs: strategy: fail-fast: false matrix: - os: ['ubuntu-latest'] + os: ['ubuntu-latest', macos-latest] pydantic-version: [">1"] python-version: - "3.10" - - "3.11" + - "3.11" - "3.12" include: - # Note: pinned to macos-12 + # Note: we still need to add support for macos-13 # see https://github.com/OpenFreeEnergy/openfe/issues/842 - - os: "macos-12" - python-version: "3.11" - pydantic-version: ">1" - os: "ubuntu-latest" python-version: "3.11" pydantic-version: "<2" diff --git a/.github/workflows/clean_cache.yaml b/.github/workflows/clean_cache.yaml index c3db0a6f..e9a6c50d 100644 --- a/.github/workflows/clean_cache.yaml +++ b/.github/workflows/clean_cache.yaml @@ -11,18 +11,18 @@ jobs: steps: - name: Check out code uses: actions/checkout@v3 - + - name: Cleanup run: | gh extension install actions/gh-actions-cache - + REPO=${{ github.repository }} BRANCH="refs/pull/${{ github.event.pull_request.number }}/merge" echo "Fetching list of cache key" cacheKeysForPR=$(gh actions-cache list -R $REPO -B $BRANCH | cut -f 1 ) - ## Setting this to not fail the workflow while deleting cache keys. + ## Setting this to not fail the workflow while deleting cache keys. set +e echo "Deleting caches..." for cacheKey in $cacheKeysForPR diff --git a/.github/workflows/conda_cron.yaml b/.github/workflows/conda_cron.yaml index 33d3a954..c4f7ec49 100644 --- a/.github/workflows/conda_cron.yaml +++ b/.github/workflows/conda_cron.yaml @@ -24,7 +24,7 @@ jobs: strategy: fail-fast: false matrix: - os: ['ubuntu-latest', 'macos-14'] + os: ['ubuntu-latest', 'macos-latest'] python-version: - "3.10" - "3.11" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..7b469f8c --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,41 @@ +ci: + autoupdate_schedule: "quarterly" + # comment / label "pre-commit.ci autofix" to a pull request to manually trigger auto-fixing + autofix_prs: false +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: check-added-large-files + - id: check-case-conflict + - id: check-executables-have-shebangs + - id: check-symlinks + - id: check-toml + - id: check-yaml + - id: debug-statements + - id: destroyed-symlinks + - id: end-of-file-fixer + exclude: '\.(graphml)$' + - id: trailing-whitespace + exclude: '\.(pdb|gro|top|sdf|xml|cif|graphml)$' +- repo: https://github.com/psf/black-pre-commit-mirror + rev: 24.2.0 + hooks: + - id: black + - id: black-jupyter +- repo: https://github.com/PyCQA/isort + rev: 5.13.2 + hooks: + - id: isort + args: ["--profile", "black", "--filter-files"] +- repo: https://github.com/econchick/interrogate + rev: 1.5.0 + hooks: + - id: interrogate + args: [--fail-under=28] + pass_filenames: false +- repo: https://github.com/asottile/pyupgrade + rev: v3.15.0 + hooks: + - id: pyupgrade + args: ["--py39-plus"] diff --git a/README.md b/README.md index 929185a4..5fc2c7e2 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ [![build](https://github.com/OpenFreeEnergy/gufe/actions/workflows/ci.yaml/badge.svg)](https://github.com/OpenFreeEnergy/gufe/actions/workflows/ci.yaml) [![coverage](https://codecov.io/gh/OpenFreeEnergy/gufe/branch/main/graph/badge.svg)](https://codecov.io/gh/OpenFreeEnergy/gufe) [![Documentation Status](https://readthedocs.org/projects/gufe/badge/?version=latest)](https://gufe.readthedocs.io/en/latest/?badge=latest) +[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/OpenFreeEnergy/gufe/main.svg)](https://results.pre-commit.ci/latest/github/OpenFreeEnergy/gufe/main) # gufe - Grand Unified Free Energy diff --git a/devtools/raise-or-close-issue.py b/devtools/raise-or-close-issue.py index ff9d7742..97406337 100644 --- a/devtools/raise-or-close-issue.py +++ b/devtools/raise-or-close-issue.py @@ -5,24 +5,24 @@ # - TITLE: A string title which the issue will have. import os -from github import Github +from github import Github if __name__ == "__main__": - git = Github(os.environ['GITHUB_TOKEN']) - status = os.environ['CI_OUTCOME'] - repo = git.get_repo('OpenFreeEnergy/gufe') - title = os.environ['TITLE'] - + git = Github(os.environ["GITHUB_TOKEN"]) + status = os.environ["CI_OUTCOME"] + repo = git.get_repo("OpenFreeEnergy/gufe") + title = os.environ["TITLE"] + target_issue = None for issue in repo.get_issues(): if issue.title == title: target_issue = issue - + # Close any issues with given title if CI returned green - if status == 'success': + if status == "success": if target_issue is not None: - target_issue.edit(state='closed') + target_issue.edit(state="closed") else: # Otherwise raise an issue if target_issue is None: diff --git a/docs/CHANGELOG.rst b/docs/CHANGELOG.rst index a50a0eda..2659e23d 100644 --- a/docs/CHANGELOG.rst +++ b/docs/CHANGELOG.rst @@ -28,5 +28,3 @@ v1.1.0 * Fixed an issue where ProtocolDAG DAG order & keys were unstable / non-deterministic between processes under some circumstances (PR #315). * Fixed a bug where edge annotations were lost when converting a ``LigandNetwork`` to graphml, all JSON codec types are now supported. - - diff --git a/docs/_static/.gitkeep b/docs/_static/.gitkeep index 8b137891..e69de29b 100644 --- a/docs/_static/.gitkeep +++ b/docs/_static/.gitkeep @@ -1 +0,0 @@ - diff --git a/docs/conf.py b/docs/conf.py index 905e4840..62d8bbdb 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -12,14 +12,15 @@ # import os import sys -sys.path.insert(0, os.path.abspath('../')) + +sys.path.insert(0, os.path.abspath("../")) # -- Project information ----------------------------------------------------- -project = 'gufe' -copyright = '2022, The OpenFE Development Team' -author = 'The OpenFE Development Team' +project = "gufe" +copyright = "2022, The OpenFE Development Team" +author = "The OpenFE Development Team" # -- General configuration --------------------------------------------------- @@ -28,14 +29,14 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'sphinx.ext.napoleon', - 'sphinxcontrib.autodoc_pydantic', - 'sphinx.ext.intersphinx', + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.napoleon", + "sphinxcontrib.autodoc_pydantic", + "sphinx.ext.intersphinx", ] -autoclass_content = 'both' +autoclass_content = "both" autodoc_default_options = { "members": True, @@ -47,20 +48,21 @@ autosummary_generate = True intersphinx_mapping = { - 'rdkit': ('https://www.rdkit.org/docs/', None), + "rdkit": ("https://www.rdkit.org/docs/", None), } # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] -autodoc_mock_imports = ["openff.models", - "rdkit", - "networkx", +autodoc_mock_imports = [ + "openff.models", + "rdkit", + "networkx", ] @@ -77,7 +79,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # replace macros rst_prolog = """ diff --git a/docs/guide/overview.rst b/docs/guide/overview.rst index 30b9fd5f..86086b63 100644 --- a/docs/guide/overview.rst +++ b/docs/guide/overview.rst @@ -25,7 +25,7 @@ Ligand network setup GUFE defines a basic API for the common case of performing alchemical transformations between small molecules, either for relative binding free energies of relative hydration free energies. This handles how mappings -between different molecules are defined for alchemical transformations, +between different molecules are defined for alchemical transformations, by defining both the :class:`.LigandAtomMapping` object that contains the details of a specific mapping, and the :class:`.AtomMapper` abstract API for an object that creates the mappings. diff --git a/docs/guide/serialization.rst b/docs/guide/serialization.rst index 3e81c9e9..c983233d 100644 --- a/docs/guide/serialization.rst +++ b/docs/guide/serialization.rst @@ -245,7 +245,7 @@ classmethod: .. Using JSON codecs outside of JSON .. --------------------------------- -.. In a custom recursive storage scheme, +.. In a custom recursive storage scheme, .. TODO: DWHS wants to write something here that describes how to use the codecs in your own non-JSON storage scheme. But this is complicated diff --git a/gufe/__init__.py b/gufe/__init__.py index 10e12d57..c943abf9 100644 --- a/gufe/__init__.py +++ b/gufe/__init__.py @@ -3,40 +3,21 @@ from importlib.metadata import version -from . import tokenization - -from . import visualization - -from .components import ( - Component, - SmallMoleculeComponent, - ProteinComponent, - SolventComponent -) - +from . import tokenization, visualization from .chemicalsystem import ChemicalSystem - -from .mapping import ( - ComponentMapping, # how individual Components relate - AtomMapping, AtomMapper, # more specific to atom based components - LigandAtomMapping, -) - -from .settings import Settings - -from .protocols import ( - Context, - Protocol, # description of a method - ProtocolUnit, # the individual step within a method - ProtocolDAG, # many Units forming a workflow - ProtocolUnitResult, # the result of a single Unit - ProtocolDAGResult, # the collected result of a DAG - ProtocolResult, # potentially many DAGs together, giving an estimate -) - -from .transformations import Transformation, NonTransformation - -from .network import AlchemicalNetwork +from .components import Component, ProteinComponent, SmallMoleculeComponent, SolventComponent from .ligandnetwork import LigandNetwork +from .mapping import AtomMapper # more specific to atom based components +from .mapping import ComponentMapping # how individual Components relate +from .mapping import AtomMapping, LigandAtomMapping +from .network import AlchemicalNetwork +from .protocols import Protocol # description of a method +from .protocols import ProtocolDAG # many Units forming a workflow +from .protocols import ProtocolDAGResult # the collected result of a DAG +from .protocols import ProtocolUnit # the individual step within a method +from .protocols import ProtocolUnitResult # the result of a single Unit +from .protocols import Context, ProtocolResult # potentially many DAGs together, giving an estimate +from .settings import Settings +from .transformations import NonTransformation, Transformation __version__ = version("gufe") diff --git a/gufe/chemicalsystem.py b/gufe/chemicalsystem.py index dedc5619..c4eb4add 100644 --- a/gufe/chemicalsystem.py +++ b/gufe/chemicalsystem.py @@ -4,8 +4,8 @@ from collections import abc from typing import Optional -from .tokenization import GufeTokenizable from .components import Component +from .tokenization import GufeTokenizable class ChemicalSystem(GufeTokenizable, abc.Mapping): @@ -14,14 +14,14 @@ def __init__( components: dict[str, Component], name: Optional[str] = "", ): - """A combination of Components that form a system + r"""A combination of Components that form a system Containing a combination of :class:`.SmallMoleculeComponent`, :class:`.SolventComponent` and :class:`.ProteinComponent`, this object typically represents all the molecules in a simulation box. Used as a node for an :class:`.AlchemicalNetwork`. - + Parameters ---------- components @@ -40,24 +40,18 @@ def __init__( self._name = name def __repr__(self): - return ( - f"{self.__class__.__name__}(name={self.name}, components={self.components})" - ) + return f"{self.__class__.__name__}(name={self.name}, components={self.components})" def _to_dict(self): return { - "components": { - key: value for key, value in sorted(self.components.items()) - }, + "components": {key: value for key, value in sorted(self.components.items())}, "name": self.name, } @classmethod def _from_dict(cls, d): return cls( - components={ - key: value for key, value in d["components"].items() - }, + components={key: value for key, value in d["components"].items()}, name=d["name"], ) @@ -86,7 +80,7 @@ def name(self): def total_charge(self): """Formal charge for the ChemicalSystem.""" # This might evaluate the property twice? - #return sum(component.total_charge + # return sum(component.total_charge # for component in self._components.values() # if component.total_charge is not None) total_charge = 0 diff --git a/gufe/components/__init__.py b/gufe/components/__init__.py index 6713db21..0f9a69ea 100644 --- a/gufe/components/__init__.py +++ b/gufe/components/__init__.py @@ -1,7 +1,6 @@ """The building blocks for defining systems""" from .component import Component - -from .smallmoleculecomponent import SmallMoleculeComponent from .proteincomponent import ProteinComponent +from .smallmoleculecomponent import SmallMoleculeComponent from .solventcomponent import SolventComponent diff --git a/gufe/components/explicitmoleculecomponent.py b/gufe/components/explicitmoleculecomponent.py index a3a513e5..88f1f196 100644 --- a/gufe/components/explicitmoleculecomponent.py +++ b/gufe/components/explicitmoleculecomponent.py @@ -1,13 +1,13 @@ import json -import numpy as np import warnings -from rdkit import Chem from typing import Optional -from .component import Component +import numpy as np +from rdkit import Chem # typing from ..custom_typing import RDKitMol +from .component import Component def _ensure_ofe_name(mol: RDKitMol, name: str) -> str: @@ -26,10 +26,7 @@ def _ensure_ofe_name(mol: RDKitMol, name: str) -> str: pass if name and rdkit_name and rdkit_name != name: - warnings.warn( - f"Component being renamed from {rdkit_name}" - f"to {name}." - ) + warnings.warn(f"Component being renamed from {rdkit_name}" f"to {name}.") elif name == "": name = rdkit_name @@ -56,21 +53,19 @@ def _check_partial_charges(mol: RDKitMol) -> None: * If partial charges are found. * If the partial charges are near 0 for all atoms. """ - if 'atom.dprop.PartialCharge' not in mol.GetPropNames(): + if "atom.dprop.PartialCharge" not in mol.GetPropNames(): return - p_chgs = np.array( - mol.GetProp('atom.dprop.PartialCharge').split(), dtype=float - ) + p_chgs = np.array(mol.GetProp("atom.dprop.PartialCharge").split(), dtype=float) if len(p_chgs) != mol.GetNumAtoms(): - errmsg = (f"Incorrect number of partial charges: {len(p_chgs)} " - f" were provided for {mol.GetNumAtoms()} atoms") + errmsg = f"Incorrect number of partial charges: {len(p_chgs)} " f" were provided for {mol.GetNumAtoms()} atoms" raise ValueError(errmsg) if (sum(p_chgs) - Chem.GetFormalCharge(mol)) > 0.01: - errmsg = (f"Sum of partial charges {sum(p_chgs)} differs from " - f"RDKit formal charge {Chem.GetFormalCharge(mol)}") + errmsg = ( + f"Sum of partial charges {sum(p_chgs)} differs from " f"RDKit formal charge {Chem.GetFormalCharge(mol)}" + ) raise ValueError(errmsg) # set the charges on the atoms if not already set @@ -81,18 +76,20 @@ def _check_partial_charges(mol: RDKitMol) -> None: else: atom_charge = atom.GetDoubleProp("PartialCharge") if not np.isclose(atom_charge, charge): - errmsg = (f"non-equivalent partial charges between atom and " - f"molecule properties: {atom_charge} {charge}") + errmsg = ( + f"non-equivalent partial charges between atom and " f"molecule properties: {atom_charge} {charge}" + ) raise ValueError(errmsg) if np.all(np.isclose(p_chgs, 0.0)): - wmsg = (f"Partial charges provided all equal to " - "zero. These may be ignored by some Protocols.") + wmsg = f"Partial charges provided all equal to " "zero. These may be ignored by some Protocols." warnings.warn(wmsg) else: - wmsg = ("Partial charges have been provided, these will " - "preferentially be used instead of generating new " - "partial charges") + wmsg = ( + "Partial charges have been provided, these will " + "preferentially be used instead of generating new " + "partial charges" + ) warnings.warn(wmsg) @@ -103,6 +100,7 @@ class ExplicitMoleculeComponent(Component): representations. Specific file formats, such as SDF for small molecules or PDB for proteins, should be implemented in subclasses. """ + _rdkit: Chem.Mol _name: str @@ -115,14 +113,13 @@ def __init__(self, rdkit: RDKitMol, name: str = ""): n_confs = len(conformers) if n_confs > 1: - warnings.warn( - f"Molecule provided with {n_confs} conformers. " - f"Only the first will be used." - ) + warnings.warn(f"Molecule provided with {n_confs} conformers. " f"Only the first will be used.") if not any(atom.GetAtomicNum() == 1 for atom in rdkit.GetAtoms()): - warnings.warn("Molecule doesn't have any hydrogen atoms present. " - "If this is unexpected, consider loading the molecule with `removeHs=False`") + warnings.warn( + "Molecule doesn't have any hydrogen atoms present. " + "If this is unexpected, consider loading the molecule with `removeHs=False`" + ) self._rdkit = rdkit self._smiles: Optional[str] = None diff --git a/gufe/components/proteincomponent.py b/gufe/components/proteincomponent.py index fce2f65c..a23ccf09 100644 --- a/gufe/components/proteincomponent.py +++ b/gufe/components/proteincomponent.py @@ -1,28 +1,24 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/gufe import ast -import json import io -import numpy as np -from os import PathLike +import json import string -from typing import Union, Optional from collections import defaultdict +from os import PathLike +from typing import Optional, Union +import numpy as np from openmm import app from openmm import unit as omm_unit - from rdkit import Chem -from rdkit.Chem.rdchem import Mol, Atom, Conformer, EditableMol, BondType +from rdkit.Chem.rdchem import Atom, BondType, Conformer, EditableMol, Mol from ..custom_typing import RDKitMol -from .explicitmoleculecomponent import ExplicitMoleculeComponent +from ..molhashing import deserialize_numpy, serialize_numpy from ..vendor.pdb_file.pdbfile import PDBFile from ..vendor.pdb_file.pdbxfile import PDBxFile - - -from ..molhashing import deserialize_numpy, serialize_numpy - +from .explicitmoleculecomponent import ExplicitMoleculeComponent _BONDORDERS_OPENMM_TO_RDKIT = { 1: BondType.SINGLE, @@ -34,9 +30,7 @@ app.Aromatic: BondType.AROMATIC, None: BondType.UNSPECIFIED, } -_BONDORDERS_RDKIT_TO_OPENMM = { - v: k for k, v in _BONDORDERS_OPENMM_TO_RDKIT.items() -} +_BONDORDERS_RDKIT_TO_OPENMM = {v: k for k, v in _BONDORDERS_OPENMM_TO_RDKIT.items()} _BONDORDER_TO_ORDER = { BondType.UNSPECIFIED: 1, # assumption BondType.SINGLE: 1, @@ -50,21 +44,29 @@ _BONDORDER_RDKIT_TO_STR = {v: k for k, v in _BONDORDER_STR_TO_RDKIT.items()} _CHIRALITY_RDKIT_TO_STR = { - Chem.CHI_TETRAHEDRAL_CW: 'CW', - Chem.CHI_TETRAHEDRAL_CCW: 'CCW', - Chem.CHI_UNSPECIFIED: 'U', -} -_CHIRALITY_STR_TO_RDKIT = { - v: k for k, v in _CHIRALITY_RDKIT_TO_STR.items() + Chem.CHI_TETRAHEDRAL_CW: "CW", + Chem.CHI_TETRAHEDRAL_CCW: "CCW", + Chem.CHI_UNSPECIFIED: "U", } +_CHIRALITY_STR_TO_RDKIT = {v: k for k, v in _CHIRALITY_RDKIT_TO_STR.items()} negative_ions = ["F", "CL", "BR", "I"] positive_ions = [ # +1 - "LI", "NA", "K", "RB", "CS", + "LI", + "NA", + "K", + "RB", + "CS", # +2 - "BE", "MG", "CA", "SR", "BA", "RA", "ZN", + "BE", + "MG", + "CA", + "SR", + "BA", + "RA", + "ZN", ] @@ -91,10 +93,13 @@ class ProteinComponent(ExplicitMoleculeComponent): edit the molecule do this in an appropriate toolkit **before** creating an instance from this class. """ + def __init__(self, rdkit: RDKitMol, name=""): if not all(a.GetMonomerInfo() is not None for a in rdkit.GetAtoms()): - raise TypeError("Not all atoms in input have MonomerInfo defined. " - "Consider loading via rdkit.Chem.MolFromPDBFile or similar.") + raise TypeError( + "Not all atoms in input have MonomerInfo defined. " + "Consider loading via rdkit.Chem.MolFromPDBFile or similar." + ) super().__init__(rdkit=rdkit, name=name) # FROM @@ -116,9 +121,7 @@ def from_pdb_file(cls, pdb_file: str, name: str = ""): the deserialized molecule """ openmm_PDBFile = PDBFile(pdb_file) - return cls._from_openmmPDBFile( - openmm_PDBFile=openmm_PDBFile, name=name - ) + return cls._from_openmmPDBFile(openmm_PDBFile=openmm_PDBFile, name=name) @classmethod def from_pdbx_file(cls, pdbx_file: str, name=""): @@ -138,13 +141,10 @@ def from_pdbx_file(cls, pdbx_file: str, name=""): the deserialized molecule """ openmm_PDBxFile = PDBxFile(pdbx_file) - return cls._from_openmmPDBFile( - openmm_PDBFile=openmm_PDBxFile, name=name - ) + return cls._from_openmmPDBFile(openmm_PDBFile=openmm_PDBxFile, name=name) @classmethod - def _from_openmmPDBFile(cls, openmm_PDBFile: Union[PDBFile, PDBxFile], - name: str = ""): + def _from_openmmPDBFile(cls, openmm_PDBFile: Union[PDBFile, PDBxFile], name: str = ""): """Converts to our internal representation (rdkit Mol) Parameters @@ -199,9 +199,7 @@ def _from_openmmPDBFile(cls, openmm_PDBFile: Union[PDBFile, PDBxFile], # Set Positions rd_mol = editable_rdmol.GetMol() - positions = np.array( - openmm_PDBFile.positions.value_in_unit(omm_unit.angstrom), ndmin=3 - ) + positions = np.array(openmm_PDBFile.positions.value_in_unit(omm_unit.angstrom), ndmin=3) for frame_id, frame in enumerate(positions): conf = Conformer(frame_id) @@ -216,18 +214,15 @@ def _from_openmmPDBFile(cls, openmm_PDBFile: Union[PDBFile, PDBxFile], atomic_num = a.GetAtomicNum() atom_name = a.GetMonomerInfo().GetName() - connectivity = sum( - _BONDORDER_TO_ORDER[bond.GetBondType()] - for bond in a.GetBonds() - ) + connectivity = sum(_BONDORDER_TO_ORDER[bond.GetBondType()] for bond in a.GetBonds()) default_valence = periodicTable.GetDefaultValence(atomic_num) if connectivity == 0: # ions: # strip catches cases like 'CL1' as name if atom_name.strip(string.digits).upper() in positive_ions: - fc = default_valence # e.g. Sodium ions + fc = default_valence # e.g. Sodium ions elif atom_name.strip(string.digits).upper() in negative_ions: - fc = - default_valence # e.g. Chlorine ions + fc = -default_valence # e.g. Chlorine ions else: # -no-cov- resn = a.GetMonomerInfo().GetResidueName() resind = int(a.GetMonomerInfo().GetResidueNumber()) @@ -237,9 +232,9 @@ def _from_openmmPDBFile(cls, openmm_PDBFile: Union[PDBFile, PDBxFile], f"connectivity{connectivity}" ) elif default_valence > connectivity: - fc = - (default_valence - connectivity) # negative charge + fc = -(default_valence - connectivity) # negative charge elif default_valence < connectivity: - fc = + (connectivity - default_valence) # positive charge + fc = +(connectivity - default_valence) # positive charge else: fc = 0 # neutral @@ -272,7 +267,7 @@ def _from_dict(cls, ser_dict: dict, name: str = ""): mi.SetName(atom[5]) mi.SetResidueName(atom[6]) mi.SetResidueNumber(int(atom[7])) - mi.SetIsHeteroAtom(atom[8] == 'Y') + mi.SetIsHeteroAtom(atom[8] == "Y") a.SetFormalCharge(atom[9]) a.SetMonomerInfo(mi) @@ -304,7 +299,7 @@ def _from_dict(cls, ser_dict: dict, name: str = ""): for bond_id, bond in enumerate(rd_mol.GetBonds()): # Can't set these on an editable mol, go round a second time _, _, _, arom = ser_dict["bonds"][bond_id] - bond.SetIsAromatic(arom == 'Y') + bond.SetIsAromatic(arom == "Y") if "name" in ser_dict: name = ser_dict["name"] @@ -319,6 +314,7 @@ def to_openmm_topology(self) -> app.Topology: openmm.app.Topology resulting topology obj. """ + def reskey(m): """key for defining when a residue has changed from previous @@ -329,7 +325,7 @@ def reskey(m): m.GetChainId(), m.GetResidueName(), m.GetResidueNumber(), - m.GetInsertionCode() + m.GetInsertionCode(), ) def chainkey(m): @@ -362,10 +358,7 @@ def chainkey(m): if (new_resid := reskey(mi)) != current_resid: _, resname, resnum, icode = new_resid - r = top.addResidue(name=resname, - chain=c, - id=str(resnum), - insertionCode=icode) + r = top.addResidue(name=resname, chain=c, id=str(resnum), insertionCode=icode) current_resid = new_resid a = top.addAtom( @@ -380,9 +373,7 @@ def chainkey(m): for bond in self._rdkit.GetBonds(): a1 = atom_lookup[bond.GetBeginAtomIdx()] a2 = atom_lookup[bond.GetEndAtomIdx()] - top.addBond(a1, a2, - order=_BONDORDERS_RDKIT_TO_OPENMM.get( - bond.GetBondType(), None)) + top.addBond(a1, a2, order=_BONDORDERS_RDKIT_TO_OPENMM.get(bond.GetBondType(), None)) return top @@ -400,9 +391,7 @@ def to_openmm_positions(self) -> omm_unit.Quantity: Quantity containing protein atom positions """ np_pos = deserialize_numpy(self.to_dict()["conformers"][0]) - openmm_pos = ( - list(map(lambda x: np.array(x), np_pos)) * omm_unit.angstrom - ) + openmm_pos = list(map(lambda x: np.array(x), np_pos)) * omm_unit.angstrom return openmm_pos @@ -429,20 +418,18 @@ def to_pdb_file(self, out_path: Union[str, bytes, PathLike[str], PathLike[bytes] # write file if not isinstance(out_path, io.TextIOBase): # allows pathlike/str; we close on completion - out_file = open(out_path, mode='w') # type: ignore + out_file = open(out_path, mode="w") # type: ignore must_close = True else: out_file = out_path # type: ignore must_close = False - + try: out_path = out_file.name except AttributeError: out_path = "" - PDBFile.writeFile( - topology=openmm_top, positions=openmm_pos, file=out_file - ) + PDBFile.writeFile(topology=openmm_top, positions=openmm_pos, file=out_file) if must_close: # we only close the file if we had to open it @@ -450,9 +437,7 @@ def to_pdb_file(self, out_path: Union[str, bytes, PathLike[str], PathLike[bytes] return out_path - def to_pdbx_file( - self, out_path: Union[str, bytes, PathLike[str], PathLike[bytes], io.TextIOBase] - ) -> str: + def to_pdbx_file(self, out_path: Union[str, bytes, PathLike[str], PathLike[bytes], io.TextIOBase]) -> str: """ serialize protein to pdbx file. @@ -471,19 +456,17 @@ def to_pdbx_file( # get pos: np_pos = deserialize_numpy(self.to_dict()["conformers"][0]) - openmm_pos = ( - list(map(lambda x: np.array(x), np_pos)) * omm_unit.angstrom - ) + openmm_pos = list(map(lambda x: np.array(x), np_pos)) * omm_unit.angstrom # write file if not isinstance(out_path, io.TextIOBase): # allows pathlike/str; we close on completion - out_file = open(out_path, mode='w') # type: ignore + out_file = open(out_path, mode="w") # type: ignore must_close = True else: out_file = out_path # type: ignore must_close = False - + try: out_path = out_file.name except AttributeError: @@ -491,7 +474,6 @@ def to_pdbx_file( PDBxFile.writeFile(topology=top, positions=openmm_pos, file=out_file) - if must_close: # we only close the file if we had to open it out_file.close() @@ -516,7 +498,7 @@ def _to_dict(self) -> dict: mi.GetName(), mi.GetResidueName(), mi.GetResidueNumber(), - 'Y' if mi.GetIsHeteroAtom() else 'N', + "Y" if mi.GetIsHeteroAtom() else "N", atom.GetFormalCharge(), ) ) @@ -526,15 +508,14 @@ def _to_dict(self) -> dict: bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), _BONDORDER_RDKIT_TO_STR[bond.GetBondType()], - 'Y' if bond.GetIsAromatic() else 'N', + "Y" if bond.GetIsAromatic() else "N", # bond.GetStereo() or "", do we need this? i.e. are openff ffs going to use cis/trans SMARTS? ) for bond in self._rdkit.GetBonds() ] conformers = [ - serialize_numpy(conf.GetPositions()) # .m_as(unit.angstrom) - for conf in self._rdkit.GetConformers() + serialize_numpy(conf.GetPositions()) for conf in self._rdkit.GetConformers() # .m_as(unit.angstrom) ] # Result diff --git a/gufe/components/smallmoleculecomponent.py b/gufe/components/smallmoleculecomponent.py index b51f2b7d..814c473b 100644 --- a/gufe/components/smallmoleculecomponent.py +++ b/gufe/components/smallmoleculecomponent.py @@ -2,16 +2,17 @@ # For details, see https://github.com/OpenFreeEnergy/gufe import logging +import warnings + # openff complains about oechem being missing, shhh -logger = logging.getLogger('openff.toolkit') +logger = logging.getLogger("openff.toolkit") logger.setLevel(logging.ERROR) from typing import Any from rdkit import Chem -from .explicitmoleculecomponent import ExplicitMoleculeComponent from ..molhashing import deserialize_numpy, serialize_numpy - +from .explicitmoleculecomponent import ExplicitMoleculeComponent _INT_TO_ATOMCHIRAL = { 0: Chem.rdchem.ChiralType.CHI_UNSPECIFIED, @@ -20,14 +21,16 @@ 3: Chem.rdchem.ChiralType.CHI_OTHER, } # support for non-tetrahedral stereo requires rdkit 2022.09.1+ -if hasattr(Chem.rdchem.ChiralType, 'CHI_TETRAHEDRAL'): - _INT_TO_ATOMCHIRAL.update({ - 4: Chem.rdchem.ChiralType.CHI_TETRAHEDRAL, - 5: Chem.rdchem.ChiralType.CHI_ALLENE, - 6: Chem.rdchem.ChiralType.CHI_SQUAREPLANAR, - 7: Chem.rdchem.ChiralType.CHI_TRIGONALBIPYRAMIDAL, - 8: Chem.rdchem.ChiralType.CHI_OCTAHEDRAL, - }) +if hasattr(Chem.rdchem.ChiralType, "CHI_TETRAHEDRAL"): + _INT_TO_ATOMCHIRAL.update( + { + 4: Chem.rdchem.ChiralType.CHI_TETRAHEDRAL, + 5: Chem.rdchem.ChiralType.CHI_ALLENE, + 6: Chem.rdchem.ChiralType.CHI_SQUAREPLANAR, + 7: Chem.rdchem.ChiralType.CHI_TRIGONALBIPYRAMIDAL, + 8: Chem.rdchem.ChiralType.CHI_OCTAHEDRAL, + } + ) _ATOMCHIRAL_TO_INT = {v: k for k, v in _INT_TO_ATOMCHIRAL.items()} @@ -53,7 +56,8 @@ 18: Chem.rdchem.BondType.DATIVEL, 19: Chem.rdchem.BondType.DATIVER, 20: Chem.rdchem.BondType.OTHER, - 21: Chem.rdchem.BondType.ZERO} + 21: Chem.rdchem.BondType.ZERO, +} _BONDTYPE_TO_INT = {v: k for k, v in _INT_TO_BONDTYPE.items()} _INT_TO_BONDSTEREO = { 0: Chem.rdchem.BondStereo.STEREONONE, @@ -61,9 +65,24 @@ 2: Chem.rdchem.BondStereo.STEREOZ, 3: Chem.rdchem.BondStereo.STEREOE, 4: Chem.rdchem.BondStereo.STEREOCIS, - 5: Chem.rdchem.BondStereo.STEREOTRANS} + 5: Chem.rdchem.BondStereo.STEREOTRANS, +} _BONDSTEREO_TO_INT = {v: k for k, v in _INT_TO_BONDSTEREO.items()} +# following the numbering in rdkit +_INT_TO_HYBRIDIZATION = { + 0: Chem.rdchem.HybridizationType.UNSPECIFIED, + 1: Chem.rdchem.HybridizationType.S, + 2: Chem.rdchem.HybridizationType.SP, + 3: Chem.rdchem.HybridizationType.SP2, + 4: Chem.rdchem.HybridizationType.SP3, + 5: Chem.rdchem.HybridizationType.SP2D, + 6: Chem.rdchem.HybridizationType.SP3D, + 7: Chem.rdchem.HybridizationType.SP3D2, + 8: Chem.rdchem.HybridizationType.OTHER, +} +_HYBRIDIZATION_TO_INT = {v: k for k, v in _INT_TO_HYBRIDIZATION.items()} + def _setprops(obj, d: dict) -> None: # add props onto rdkit "obj" (atom/bond/mol/conformer) @@ -120,8 +139,8 @@ def to_sdf(self) -> str: sdf = [Chem.MolToMolBlock(mol)] for prop in mol.GetPropNames(): val = mol.GetProp(prop) - sdf.append('> <%s>\n%s\n' % (prop, val)) - sdf.append('$$$$\n') + sdf.append(f"> <{prop}>\n{val}\n") + sdf.append("$$$$\n") return "\n".join(sdf) @classmethod @@ -210,26 +229,40 @@ def _to_dict(self) -> dict: atoms = [] for atom in self._rdkit.GetAtoms(): - atoms.append(( - atom.GetAtomicNum(), atom.GetIsotope(), atom.GetFormalCharge(), atom.GetIsAromatic(), - _ATOMCHIRAL_TO_INT[atom.GetChiralTag()], atom.GetAtomMapNum(), - atom.GetPropsAsDict(includePrivate=False), - )) - output['atoms'] = atoms + atoms.append( + ( + atom.GetAtomicNum(), + atom.GetIsotope(), + atom.GetFormalCharge(), + atom.GetIsAromatic(), + _ATOMCHIRAL_TO_INT[atom.GetChiralTag()], + atom.GetAtomMapNum(), + atom.GetPropsAsDict(includePrivate=False), + _HYBRIDIZATION_TO_INT[atom.GetHybridization()], + ) + ) + output["atoms"] = atoms bonds = [] for bond in self._rdkit.GetBonds(): - bonds.append(( - bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), _BONDTYPE_TO_INT[bond.GetBondType()], - _BONDSTEREO_TO_INT[bond.GetStereo()], - bond.GetPropsAsDict(includePrivate=False) - )) - output['bonds'] = bonds + bonds.append( + ( + bond.GetBeginAtomIdx(), + bond.GetEndAtomIdx(), + _BONDTYPE_TO_INT[bond.GetBondType()], + _BONDSTEREO_TO_INT[bond.GetStereo()], + bond.GetPropsAsDict(includePrivate=False), + ) + ) + output["bonds"] = bonds conf = self._rdkit.GetConformer() - output['conformer'] = (serialize_numpy(conf.GetPositions()), conf.GetPropsAsDict(includePrivate=False)) + output["conformer"] = ( + serialize_numpy(conf.GetPositions()), + conf.GetPropsAsDict(includePrivate=False), + ) - output['molprops'] = self._rdkit.GetPropsAsDict(includePrivate=False) + output["molprops"] = self._rdkit.GetPropsAsDict(includePrivate=False) return output @@ -239,7 +272,7 @@ def _from_dict(cls, d: dict): m = Chem.Mol() em = Chem.EditableMol(m) - for atom in d['atoms']: + for atom in d["atoms"]: a = Chem.Atom(atom[0]) a.SetIsotope(atom[1]) a.SetFormalCharge(atom[2]) @@ -247,26 +280,36 @@ def _from_dict(cls, d: dict): a.SetChiralTag(_INT_TO_ATOMCHIRAL[atom[4]]) a.SetAtomMapNum(atom[5]) _setprops(a, atom[6]) + try: + a.SetHybridization(_INT_TO_HYBRIDIZATION[atom[7]]) + except IndexError: + warnings.warn( + "The atom hybridization data was not found and has been set to unspecified. This can be" + " fixed by recreating the SmallMoleculeComponent from the rdkit molecule after running " + "sanitization." + ) + pass + em.AddAtom(a) - for bond in d['bonds']: + for bond in d["bonds"]: em.AddBond(bond[0], bond[1], _INT_TO_BONDTYPE[bond[2]]) # other fields are applied onto the ROMol m = em.GetMol() - for bond, b in zip(d['bonds'], m.GetBonds()): + for bond, b in zip(d["bonds"], m.GetBonds()): b.SetStereo(_INT_TO_BONDSTEREO[bond[3]]) _setprops(b, bond[4]) - pos = deserialize_numpy(d['conformer'][0]) + pos = deserialize_numpy(d["conformer"][0]) c = Chem.Conformer(m.GetNumAtoms()) for i, p in enumerate(pos): c.SetAtomPosition(i, p) - _setprops(c, d['conformer'][1]) + _setprops(c, d["conformer"][1]) m.AddConformer(c) - _setprops(m, d['molprops']) + _setprops(m, d["molprops"]) m.UpdatePropertyCache() @@ -275,10 +318,10 @@ def _from_dict(cls, d: dict): def copy_with_replacements(self, **replacements): # this implementation first makes a copy with the name replaced # only, then does any other replacements that are necessary - if 'name' in replacements: - name = replacements.pop('name') + if "name" in replacements: + name = replacements.pop("name") dct = self._to_dict() - dct['molprops']['ofe-name'] = name + dct["molprops"]["ofe-name"] = name obj = self._from_dict(dct) else: obj = self diff --git a/gufe/components/solventcomponent.py b/gufe/components/solventcomponent.py index a323a3c6..8b9b01e1 100644 --- a/gufe/components/solventcomponent.py +++ b/gufe/components/solventcomponent.py @@ -2,13 +2,14 @@ # For details, see https://github.com/OpenFreeEnergy/gufe from __future__ import annotations -from openff.units import unit from typing import Optional, Tuple +from openff.units import unit + from .component import Component -_CATIONS = {'Cs', 'K', 'Li', 'Na', 'Rb'} -_ANIONS = {'Cl', 'Br', 'F', 'I'} +_CATIONS = {"Cs", "K", "Li", "Na", "Rb"} +_ANIONS = {"Cl", "Br", "F", "I"} # really wanted to make this a dataclass but then can't sort & strip ion input @@ -21,18 +22,22 @@ class SolventComponent(Component): and their coordinates. This abstract representation is later made concrete by specific MD engine methods. """ + _smiles: str - _positive_ion: Optional[str] - _negative_ion: Optional[str] + _positive_ion: str | None + _negative_ion: str | None _neutralize: bool _ion_concentration: unit.Quantity - def __init__(self, *, # force kwarg usage - smiles: str = 'O', - positive_ion: str = 'Na+', - negative_ion: str = 'Cl-', - neutralize: bool = True, - ion_concentration: unit.Quantity = 0.15 * unit.molar): + def __init__( + self, + *, # force kwarg usage + smiles: str = "O", + positive_ion: str = "Na+", + negative_ion: str = "Cl-", + neutralize: bool = True, + ion_concentration: unit.Quantity = 0.15 * unit.molar, + ): """ Parameters ---------- @@ -58,26 +63,23 @@ def __init__(self, *, # force kwarg usage """ self._smiles = smiles - norm = positive_ion.strip('-+').capitalize() + norm = positive_ion.strip("-+").capitalize() if norm not in _CATIONS: raise ValueError(f"Invalid positive ion, got {positive_ion}") - positive_ion = norm + '+' + positive_ion = norm + "+" self._positive_ion = positive_ion - norm = negative_ion.strip('-+').capitalize() + norm = negative_ion.strip("-+").capitalize() if norm not in _ANIONS: raise ValueError(f"Invalid negative ion, got {negative_ion}") - negative_ion = norm + '-' + negative_ion = norm + "-" self._negative_ion = negative_ion self._neutralize = neutralize - if (not isinstance(ion_concentration, unit.Quantity) - or not ion_concentration.is_compatible_with(unit.molar)): - raise ValueError(f"ion_concentration must be given in units of" - f" concentration, got: {ion_concentration}") + if not isinstance(ion_concentration, unit.Quantity) or not ion_concentration.is_compatible_with(unit.molar): + raise ValueError(f"ion_concentration must be given in units of" f" concentration, got: {ion_concentration}") if ion_concentration.m < 0: - raise ValueError(f"ion_concentration must be positive, " - f"got: {ion_concentration}") + raise ValueError(f"ion_concentration must be positive, " f"got: {ion_concentration}") self._ion_concentration = ion_concentration @@ -91,12 +93,12 @@ def smiles(self) -> str: return self._smiles @property - def positive_ion(self) -> Optional[str]: + def positive_ion(self) -> str | None: """The cation in the solvent state""" return self._positive_ion @property - def negative_ion(self) -> Optional[str]: + def negative_ion(self) -> str | None: """The anion in the solvent state""" return self._negative_ion @@ -118,8 +120,8 @@ def total_charge(self): @classmethod def _from_dict(cls, d): """Deserialize from dict representation""" - ion_conc = d['ion_concentration'] - d['ion_concentration'] = unit.parse_expression(ion_conc) + ion_conc = d["ion_concentration"] + d["ion_concentration"] = unit.parse_expression(ion_conc) return cls(**d) @@ -127,10 +129,13 @@ def _to_dict(self): """For serialization""" ion_conc = str(self.ion_concentration) - return {'smiles': self.smiles, 'positive_ion': self.positive_ion, - 'negative_ion': self.negative_ion, - 'ion_concentration': ion_conc, - 'neutralize': self._neutralize} + return { + "smiles": self.smiles, + "positive_ion": self.positive_ion, + "negative_ion": self.negative_ion, + "ion_concentration": ion_conc, + "neutralize": self._neutralize, + } @classmethod def _defaults(cls): diff --git a/gufe/custom_codecs.py b/gufe/custom_codecs.py index a2ed55ea..7499ccd6 100644 --- a/gufe/custom_codecs.py +++ b/gufe/custom_codecs.py @@ -5,10 +5,10 @@ import datetime import functools import pathlib +from uuid import UUID import numpy as np from openff.units import DEFAULT_UNIT_REGISTRY -from uuid import UUID import gufe from gufe.custom_json import JSONCodec @@ -62,15 +62,15 @@ def is_openff_quantity_dict(dct): BYTES_CODEC = JSONCodec( cls=bytes, - to_dict=lambda obj: {'latin-1': obj.decode('latin-1')}, - from_dict=lambda dct: dct['latin-1'].encode('latin-1'), + to_dict=lambda obj: {"latin-1": obj.decode("latin-1")}, + from_dict=lambda dct: dct["latin-1"].encode("latin-1"), ) DATETIME_CODEC = JSONCodec( cls=datetime.datetime, - to_dict=lambda obj: {'isotime': obj.isoformat()}, - from_dict=lambda dct: datetime.datetime.fromisoformat(dct['isotime']), + to_dict=lambda obj: {"isotime": obj.isoformat()}, + from_dict=lambda dct: datetime.datetime.fromisoformat(dct["isotime"]), ) # Note that this has inconsistent behaviour for some generic types @@ -80,12 +80,10 @@ def is_openff_quantity_dict(dct): NPY_DTYPE_CODEC = JSONCodec( cls=np.generic, to_dict=lambda obj: { - 'dtype': str(obj.dtype), - 'bytes': obj.tobytes(), + "dtype": str(obj.dtype), + "bytes": obj.tobytes(), }, - from_dict=lambda dct: np.frombuffer( - dct['bytes'], dtype=np.dtype(dct['dtype']) - )[0], + from_dict=lambda dct: np.frombuffer(dct["bytes"], dtype=np.dtype(dct["dtype"]))[0], is_my_obj=lambda obj: isinstance(obj, np.generic), is_my_dict=is_npy_dtype_dict, ) @@ -94,13 +92,11 @@ def is_openff_quantity_dict(dct): NUMPY_CODEC = JSONCodec( cls=np.ndarray, to_dict=lambda obj: { - 'dtype': str(obj.dtype), - 'shape': list(obj.shape), - 'bytes': obj.tobytes() + "dtype": str(obj.dtype), + "shape": list(obj.shape), + "bytes": obj.tobytes(), }, - from_dict=lambda dct: np.frombuffer( - dct['bytes'], dtype=np.dtype(dct['dtype']) - ).reshape(dct['shape']) + from_dict=lambda dct: np.frombuffer(dct["bytes"], dtype=np.dtype(dct["dtype"])).reshape(dct["shape"]), ) @@ -120,7 +116,7 @@ def is_openff_quantity_dict(dct): ":is_custom:": True, "pint_unit_registry": "openff_units", }, - from_dict=lambda dct: dct['magnitude'] * DEFAULT_UNIT_REGISTRY.Quantity(dct['unit']), + from_dict=lambda dct: dct["magnitude"] * DEFAULT_UNIT_REGISTRY.Quantity(dct["unit"]), is_my_obj=lambda obj: isinstance(obj, DEFAULT_UNIT_REGISTRY.Quantity), is_my_dict=is_openff_quantity_dict, ) diff --git a/gufe/custom_json.py b/gufe/custom_json.py index fc14a9aa..535487f6 100644 --- a/gufe/custom_json.py +++ b/gufe/custom_json.py @@ -5,10 +5,11 @@ import functools import json -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union +from collections.abc import Iterable +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union -class JSONCodec(object): +class JSONCodec: """Custom JSON encoding and decoding for non-default types. Parameters @@ -32,13 +33,14 @@ class JSONCodec(object): by this decoder. Default behavior assumes usage of the default ``is_my_obj``. """ + def __init__( self, cls: Union[type, None], - to_dict: Callable[[Any], Dict], - from_dict: Callable[[Dict], Any], + to_dict: Callable[[Any], dict], + from_dict: Callable[[dict], Any], is_my_obj: Optional[Callable[[Any], bool]] = None, - is_my_dict=None + is_my_dict=None, ): if is_my_obj is None: is_my_obj = self._is_my_obj @@ -53,7 +55,7 @@ def __init__( self.is_my_dict = is_my_dict def _is_my_dict(self, dct: dict) -> bool: - expected = ['__class__', '__module__', ':is_custom:'] + expected = ["__class__", "__module__", ":is_custom:"] is_custom = all(exp in dct for exp in expected) return ( is_custom @@ -80,7 +82,7 @@ def default(self, obj: Any) -> Any: return dct return obj - def object_hook(self, dct: Dict) -> Any: + def object_hook(self, dct: dict) -> Any: if self.is_my_dict(dct): obj = self.from_dict(dct) return obj @@ -88,8 +90,8 @@ def object_hook(self, dct: Dict) -> Any: def custom_json_factory( - coding_methods: Iterable[JSONCodec] -) -> Tuple[Type[json.JSONEncoder], Type[json.JSONDecoder]]: + coding_methods: Iterable[JSONCodec], +) -> tuple[type[json.JSONEncoder], type[json.JSONDecoder]]: """Create JSONEncoder/JSONDecoder for special types. Factory method. Dynamically creates classes that enable all the provided @@ -107,6 +109,7 @@ def custom_json_factory( subclasses of JSONEncoder/JSONDecoder that use support the provided codecs """ + class CustomJSONEncoder(json.JSONEncoder): def default(self, obj): for coding_method in coding_methods: @@ -125,9 +128,7 @@ class CustomJSONDecoder(json.JSONDecoder): def __init__(self, *args, **kwargs): # technically, JSONDecoder doesn't come with an object_hook # method, which is why we pass it to super here - super(CustomJSONDecoder, self).__init__( - object_hook=self.object_hook, *args, **kwargs - ) + super().__init__(object_hook=self.object_hook, *args, **kwargs) def object_hook(self, dct): for coding_method in coding_methods: @@ -144,8 +145,8 @@ def object_hook(self, dct): return (CustomJSONEncoder, CustomJSONDecoder) -class JSONSerializerDeserializer(object): - """ +class JSONSerializerDeserializer: + r""" Tools to serialize and deserialize objects as JSON. This wrapper object is necessary so that we can register new codecs @@ -164,15 +165,17 @@ class JSONSerializerDeserializer(object): codecs : list of :class:`.JSONCodec`\s codecs supported """ + def __init__(self, codecs: Iterable[JSONCodec]): - self.codecs: List[JSONCodec] = [] + self.codecs: list[JSONCodec] = [] for codec in codecs: self.add_codec(codec) self.encoder, self.decoder = self._set_serialization() - def _set_serialization(self) -> Tuple[Type[json.JSONEncoder], - Type[json.JSONDecoder]]: + def _set_serialization( + self, + ) -> tuple[type[json.JSONEncoder], type[json.JSONDecoder]]: encoder, decoder = custom_json_factory(self.codecs) self._serializer = functools.partial(json.dumps, cls=encoder) self._deserializer = functools.partial(json.loads, cls=decoder) @@ -194,7 +197,6 @@ def add_codec(self, codec: JSONCodec): self.encoder, self.decoder = self._set_serialization() - def serializer(self, obj: Any) -> str: """Callable that dumps to JSON""" return self._serializer(obj) diff --git a/gufe/custom_typing.py b/gufe/custom_typing.py index 3a1ce2a9..38f91625 100644 --- a/gufe/custom_typing.py +++ b/gufe/custom_typing.py @@ -2,6 +2,7 @@ # For details, see https://github.com/OpenFreeEnergy/gufe from typing import TypeVar + from rdkit import Chem try: diff --git a/gufe/ligandnetwork.py b/gufe/ligandnetwork.py index c43dce57..45eac118 100644 --- a/gufe/ligandnetwork.py +++ b/gufe/ligandnetwork.py @@ -2,21 +2,24 @@ # For details, see https://github.com/OpenFreeEnergy/gufe from __future__ import annotations -from itertools import chain import json +from collections.abc import Iterable +from itertools import chain +from typing import FrozenSet, Optional + import networkx as nx -from typing import FrozenSet, Iterable, Optional -import gufe +import gufe from gufe import SmallMoleculeComponent + from .mapping import LigandAtomMapping -from .tokenization import GufeTokenizable, JSON_HANDLER +from .tokenization import JSON_HANDLER, GufeTokenizable class LigandNetwork(GufeTokenizable): """A directed graph connecting ligands according to their atom mapping. A network can be defined by specifying only edges, in which case the nodes are implicitly added. - + Parameters ---------- @@ -26,17 +29,17 @@ class LigandNetwork(GufeTokenizable): Nodes for this network. Any nodes already included as a part of the 'edges' will be ignored. Nodes not already included in 'edges' will be added as isolated, unconnected nodes. """ + def __init__( self, edges: Iterable[LigandAtomMapping], - nodes: Optional[Iterable[SmallMoleculeComponent]] = None + nodes: Iterable[SmallMoleculeComponent] | None = None, ): if nodes is None: nodes = [] self._edges = frozenset(edges) - edge_nodes = set(chain.from_iterable((e.componentA, e.componentB) - for e in edges)) + edge_nodes = set(chain.from_iterable((e.componentA, e.componentB) for e in edges)) self._nodes = frozenset(edge_nodes) | frozenset(nodes) self._graph = None @@ -45,11 +48,11 @@ def _defaults(cls): return {} def _to_dict(self) -> dict: - return {'graphml': self.to_graphml()} + return {"graphml": self.to_graphml()} @classmethod def _from_dict(cls, dct: dict): - return cls.from_graphml(dct['graphml']) + return cls.from_graphml(dct["graphml"]) @property def graph(self) -> nx.MultiDiGraph: @@ -65,20 +68,19 @@ def graph(self) -> nx.MultiDiGraph: for node in sorted(self._nodes): graph.add_node(node) for edge in sorted(self._edges): - graph.add_edge(edge.componentA, edge.componentB, object=edge, - **edge.annotations) + graph.add_edge(edge.componentA, edge.componentB, object=edge, **edge.annotations) self._graph = nx.freeze(graph) return self._graph @property - def edges(self) -> FrozenSet[LigandAtomMapping]: + def edges(self) -> frozenset[LigandAtomMapping]: """A read-only view of the edges of the Network""" return self._edges @property - def nodes(self) -> FrozenSet[SmallMoleculeComponent]: + def nodes(self) -> frozenset[SmallMoleculeComponent]: """A read-only view of the nodes of the Network""" return self._nodes @@ -93,25 +95,24 @@ def _serializable_graph(self) -> nx.Graph: # identical networks will show no changes if you diff their # serialized versions sorted_nodes = sorted(self.nodes, key=lambda m: (m.smiles, m.name)) - mol_to_label = {mol: f"mol{num}" - for num, mol in enumerate(sorted_nodes)} - - edge_data = sorted([ - ( - mol_to_label[edge.componentA], - mol_to_label[edge.componentB], - json.dumps(list(edge.componentA_to_componentB.items())), - json.dumps(edge.annotations, cls=JSON_HANDLER.encoder), - ) - for edge in self.edges - ]) + mol_to_label = {mol: f"mol{num}" for num, mol in enumerate(sorted_nodes)} + + edge_data = sorted( + [ + ( + mol_to_label[edge.componentA], + mol_to_label[edge.componentB], + json.dumps(list(edge.componentA_to_componentB.items())), + json.dumps(edge.annotations, cls=JSON_HANDLER.encoder), + ) + for edge in self.edges + ] + ) # from here, we just build the graph serializable_graph = nx.MultiDiGraph() for mol, label in mol_to_label.items(): - serializable_graph.add_node(label, - moldict=json.dumps(mol.to_dict(), - sort_keys=True)) + serializable_graph.add_node(label, moldict=json.dumps(mol.to_dict(), sort_keys=True)) for molA, molB, mapping, annotation in edge_data: serializable_graph.add_edge(molA, molB, mapping=mapping, annotations=annotation) @@ -124,15 +125,19 @@ def _from_serializable_graph(cls, graph: nx.Graph): This is the inverse of ``_serializable_graph``. """ - label_to_mol = {node: SmallMoleculeComponent.from_dict(json.loads(d)) - for node, d in graph.nodes(data='moldict')} + label_to_mol = { + node: SmallMoleculeComponent.from_dict(json.loads(d)) for node, d in graph.nodes(data="moldict") + } edges = [ - LigandAtomMapping(componentA=label_to_mol[node1], - componentB=label_to_mol[node2], - componentA_to_componentB=dict(json.loads(edge_data["mapping"])), - annotations=json.loads(edge_data.get("annotations", 'null'), cls=JSON_HANDLER.decoder) # work around old graphml files with missing edge annotations - ) + LigandAtomMapping( + componentA=label_to_mol[node1], + componentB=label_to_mol[node2], + componentA_to_componentB=dict(json.loads(edge_data["mapping"])), + annotations=json.loads( + edge_data.get("annotations", "null"), cls=JSON_HANDLER.decoder + ), # work around old graphml files with missing edge annotations + ) for node1, node2, edge_data in graph.edges(data=True) ] @@ -183,10 +188,10 @@ def enlarge_graph(self, *, edges=None, nodes=None) -> LigandNetwork: a new network adding the given edges and nodes to this network """ if edges is None: - edges = set([]) + edges = set() if nodes is None: - nodes = set([]) + nodes = set() return LigandNetwork(self.edges | set(edges), self.nodes | set(nodes)) @@ -198,7 +203,7 @@ def _to_rfe_alchemical_network( *, alchemical_label: str = "ligand", autoname=True, - autoname_prefix="" + autoname_prefix="", ) -> gufe.AlchemicalNetwork: """ Parameters @@ -228,8 +233,7 @@ def sys_from_dict(component): """ syscomps = {alchemical_label: component} other_labels = set(labels) - {alchemical_label} - syscomps.update({label: components[label] - for label in other_labels}) + syscomps.update({label: components[label] for label in other_labels}) if autoname: name = f"{component.name}_{leg_name}" @@ -246,9 +250,7 @@ def sys_from_dict(component): else: name = "" - transformation = gufe.Transformation(sysA, sysB, protocol, - mapping=edge, - name=name) + transformation = gufe.Transformation(sysA, sysB, protocol, mapping=edge, name=name) transformations.append(transformation) @@ -262,7 +264,7 @@ def to_rbfe_alchemical_network( *, autoname: bool = True, autoname_prefix: str = "easy_rbfe", - **other_components + **other_components, ) -> gufe.AlchemicalNetwork: """Convert the ligand network to an AlchemicalNetwork @@ -279,22 +281,17 @@ def to_rbfe_alchemical_network( additional non-alchemical components, keyword will be the string label for the component """ - components = { - 'protein': protein, - 'solvent': solvent, - **other_components - } + components = {"protein": protein, "solvent": solvent, **other_components} leg_labels = { "solvent": ["ligand", "solvent"], - "complex": (["ligand", "solvent", "protein"] - + list(other_components)), + "complex": (["ligand", "solvent", "protein"] + list(other_components)), } return self._to_rfe_alchemical_network( components=components, leg_labels=leg_labels, protocol=protocol, autoname=autoname, - autoname_prefix=autoname_prefix + autoname_prefix=autoname_prefix, ) # on hold until we figure out how to best hack in the PME/NoCutoff diff --git a/gufe/mapping/__init__.py b/gufe/mapping/__init__.py index bd6be51b..e77a7208 100644 --- a/gufe/mapping/__init__.py +++ b/gufe/mapping/__init__.py @@ -1,7 +1,7 @@ # This code is part of gufe and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/gufe """Defining the relationship between different components""" -from .componentmapping import ComponentMapping -from .atom_mapping import AtomMapping from .atom_mapper import AtomMapper +from .atom_mapping import AtomMapping +from .componentmapping import ComponentMapping from .ligandatommapping import LigandAtomMapping diff --git a/gufe/mapping/atom_mapper.py b/gufe/mapping/atom_mapper.py index 0fe1c9cd..c2844979 100644 --- a/gufe/mapping/atom_mapper.py +++ b/gufe/mapping/atom_mapper.py @@ -2,6 +2,7 @@ # For details, see https://github.com/OpenFreeEnergy/gufe import abc from collections.abc import Iterator + import gufe from ..tokenization import GufeTokenizable @@ -18,10 +19,7 @@ class AtomMapper(GufeTokenizable): """ @abc.abstractmethod - def suggest_mappings(self, - A: gufe.Component, - B: gufe.Component - ) -> Iterator[AtomMapping]: + def suggest_mappings(self, A: gufe.Component, B: gufe.Component) -> Iterator[AtomMapping]: """Suggests possible mappings between two Components Suggests zero or more :class:`.AtomMapping` objects, which are possible diff --git a/gufe/mapping/atom_mapping.py b/gufe/mapping/atom_mapping.py index 58c617f3..e91bf2d4 100644 --- a/gufe/mapping/atom_mapping.py +++ b/gufe/mapping/atom_mapping.py @@ -2,10 +2,10 @@ # For details, see https://github.com/OpenFreeEnergy/gufe import abc -from collections.abc import Mapping, Iterable - +from collections.abc import Iterable, Mapping import gufe + from .componentmapping import ComponentMapping diff --git a/gufe/mapping/componentmapping.py b/gufe/mapping/componentmapping.py index 92a098f3..9c1c6965 100644 --- a/gufe/mapping/componentmapping.py +++ b/gufe/mapping/componentmapping.py @@ -11,6 +11,7 @@ class ComponentMapping(GufeTokenizable, abc.ABC): For components that are atom-based is specialised to :class:`.AtomMapping` """ + _componentA: gufe.Component _componentB: gufe.Component diff --git a/gufe/mapping/ligandatommapping.py b/gufe/mapping/ligandatommapping.py index cc8ada71..4d322000 100644 --- a/gufe/mapping/ligandatommapping.py +++ b/gufe/mapping/ligandatommapping.py @@ -4,13 +4,15 @@ import json from typing import Any, Optional + import numpy as np from numpy.typing import NDArray from gufe.components import SmallMoleculeComponent from gufe.visualization.mapping_visualization import draw_mapping -from . import AtomMapping + from ..tokenization import JSON_HANDLER +from . import AtomMapping class LigandAtomMapping(AtomMapping): @@ -21,6 +23,7 @@ class LigandAtomMapping(AtomMapping): :class:`.SmallMoleculeComponent` which stores the mapping as a dict of integers. """ + componentA: SmallMoleculeComponent componentB: SmallMoleculeComponent _annotations: dict[str, Any] @@ -31,7 +34,7 @@ def __init__( componentA: SmallMoleculeComponent, componentB: SmallMoleculeComponent, componentA_to_componentB: dict[int, int], - annotations: Optional[dict[str, Any]] = None, + annotations: dict[str, Any] | None = None, ): """ Parameters @@ -57,11 +60,9 @@ def __init__( nB = self.componentB.to_rdkit().GetNumAtoms() for i, j in componentA_to_componentB.items(): if not (0 <= i < nA): - raise ValueError(f"Got invalid index for ComponentA ({i}); " - f"must be 0 <= n < {nA}") + raise ValueError(f"Got invalid index for ComponentA ({i}); " f"must be 0 <= n < {nA}") if not (0 <= j < nB): - raise ValueError(f"Got invalid index for ComponentB ({i}); " - f"must be 0 <= n < {nB}") + raise ValueError(f"Got invalid index for ComponentB ({i}); " f"must be 0 <= n < {nB}") self._compA_to_compB = componentA_to_componentB @@ -84,48 +85,51 @@ def componentB_to_componentA(self) -> dict[int, int]: @property def componentA_unique(self): - return (i for i in range(self.componentA.to_rdkit().GetNumAtoms()) - if i not in self._compA_to_compB) + return (i for i in range(self.componentA.to_rdkit().GetNumAtoms()) if i not in self._compA_to_compB) @property def componentB_unique(self): - return (i for i in range(self.componentB.to_rdkit().GetNumAtoms()) - if i not in self._compA_to_compB.values()) + return (i for i in range(self.componentB.to_rdkit().GetNumAtoms()) if i not in self._compA_to_compB.values()) @property def annotations(self): """Any extra metadata, for example the score of a mapping""" # return a copy (including copy of nested) - return json.loads(json.dumps(self._annotations, cls=JSON_HANDLER.encoder), cls=JSON_HANDLER.decoder) + return json.loads( + json.dumps(self._annotations, cls=JSON_HANDLER.encoder), + cls=JSON_HANDLER.decoder, + ) def _to_dict(self): """Serialize to dict""" return { - 'componentA': self.componentA, - 'componentB': self.componentB, - 'componentA_to_componentB': self._compA_to_compB, - 'annotations': json.dumps(self._annotations, sort_keys=True, cls=JSON_HANDLER.encoder), + "componentA": self.componentA, + "componentB": self.componentB, + "componentA_to_componentB": self._compA_to_compB, + "annotations": json.dumps(self._annotations, sort_keys=True, cls=JSON_HANDLER.encoder), } @classmethod def _from_dict(cls, d: dict): """Deserialize from dict""" # the mapping dict gets mangled sometimes - mapping = d['componentA_to_componentB'] + mapping = d["componentA_to_componentB"] fixed = {int(k): int(v) for k, v in mapping.items()} return cls( - componentA=d['componentA'], - componentB=d['componentB'], + componentA=d["componentA"], + componentB=d["componentB"], componentA_to_componentB=fixed, - annotations=json.loads(d['annotations'], cls=JSON_HANDLER.decoder) + annotations=json.loads(d["annotations"], cls=JSON_HANDLER.decoder), ) def __repr__(self): - return (f"{self.__class__.__name__}(componentA={self.componentA!r}, " - f"componentB={self.componentB!r}, " - f"componentA_to_componentB={self._compA_to_compB!r}, " - f"annotations={self.annotations!r})") + return ( + f"{self.__class__.__name__}(componentA={self.componentA!r}, " + f"componentB={self.componentB!r}, " + f"componentA_to_componentB={self._compA_to_compB!r}, " + f"annotations={self.annotations!r})" + ) def _ipython_display_(self, d2d=None): # pragma: no-cover """ @@ -144,9 +148,16 @@ def _ipython_display_(self, d2d=None): # pragma: no-cover """ from IPython.display import Image, display - return display(Image(draw_mapping(self._compA_to_compB, - self.componentA.to_rdkit(), - self.componentB.to_rdkit(), d2d))) + return display( + Image( + draw_mapping( + self._compA_to_compB, + self.componentA.to_rdkit(), + self.componentB.to_rdkit(), + d2d, + ) + ) + ) def draw_to_file(self, fname: str, d2d=None): """ @@ -162,16 +173,26 @@ def draw_to_file(self, fname: str, d2d=None): fname : str Name of file to save atom map """ - data = draw_mapping(self._compA_to_compB, self.componentA.to_rdkit(), - self.componentB.to_rdkit(), d2d) + data = draw_mapping( + self._compA_to_compB, + self.componentA.to_rdkit(), + self.componentB.to_rdkit(), + d2d, + ) if type(data) == bytes: mode = "wb" else: mode = "w" with open(fname, mode) as f: - f.write(draw_mapping(self._compA_to_compB, self.componentA.to_rdkit(), - self.componentB.to_rdkit(), d2d)) + f.write( + draw_mapping( + self._compA_to_compB, + self.componentA.to_rdkit(), + self.componentB.to_rdkit(), + d2d, + ) + ) def with_annotations(self, annotations: dict[str, Any]) -> LigandAtomMapping: """Create a new mapping based on this one with extra annotations. @@ -187,7 +208,7 @@ def with_annotations(self, annotations: dict[str, Any]) -> LigandAtomMapping: componentA=self.componentA, componentB=self.componentB, componentA_to_componentB=self._compA_to_compB, - annotations=dict(**self.annotations, **annotations) + annotations=dict(**self.annotations, **annotations), ) def get_distances(self) -> NDArray[np.float64]: diff --git a/gufe/molhashing.py b/gufe/molhashing.py index 37bd64cf..5c8daf08 100644 --- a/gufe/molhashing.py +++ b/gufe/molhashing.py @@ -1,6 +1,7 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/gufe import io + import numpy as np @@ -15,7 +16,7 @@ def serialize_numpy(arr: np.ndarray) -> str: np.save(npbytes, arr, allow_pickle=False) npbytes.seek(0) # latin-1 or base64? latin-1 is fewer bytes, but arguably worse on eyes - return npbytes.read().decode('latin-1') + return npbytes.read().decode("latin-1") def deserialize_numpy(arr_str: str) -> np.ndarray: @@ -25,6 +26,6 @@ def deserialize_numpy(arr_str: str) -> np.ndarray: ------- :func:`.serialize_numpy` """ - npbytes = io.BytesIO(arr_str.encode('latin-1')) + npbytes = io.BytesIO(arr_str.encode("latin-1")) npbytes.seek(0) return np.load(npbytes) diff --git a/gufe/network.py b/gufe/network.py index 59929d68..18bb06d9 100644 --- a/gufe/network.py +++ b/gufe/network.py @@ -1,17 +1,17 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/gufe -from typing import Generator, Iterable, Optional -from typing_extensions import Self # Self is included in typing as of python 3.11 +from collections.abc import Generator, Iterable +from typing import Optional import networkx as nx -from .tokenization import GufeTokenizable +from typing_extensions import Self # Self is included in typing as of python 3.11 from .chemicalsystem import ChemicalSystem +from .tokenization import GufeTokenizable from .transformations import Transformation - class AlchemicalNetwork(GufeTokenizable): _edges: frozenset[Transformation] _nodes: frozenset[ChemicalSystem] @@ -31,6 +31,7 @@ class AlchemicalNetwork(GufeTokenizable): the individual chemical states. :class:`.ChemicalSystem` objects from Transformation objects in edges will be automatically extracted """ + def __init__( self, edges: Optional[Iterable[Transformation]] = None, @@ -47,11 +48,7 @@ def __init__( else: self._nodes = frozenset(nodes) - self._nodes = ( - self._nodes - | frozenset(e.stateA for e in self._edges) - | frozenset(e.stateB for e in self._edges) - ) + self._nodes = self._nodes | frozenset(e.stateA for e in self._edges) | frozenset(e.stateB for e in self._edges) self._graph = None @@ -60,9 +57,7 @@ def _generate_graph(edges, nodes) -> nx.MultiDiGraph: g = nx.MultiDiGraph() for transformation in edges: - g.add_edge( - transformation.stateA, transformation.stateB, object=transformation - ) + g.add_edge(transformation.stateA, transformation.stateB, object=transformation) g.add_nodes_from(nodes) @@ -99,15 +94,15 @@ def name(self) -> Optional[str]: return self._name def _to_dict(self) -> dict: - return {"nodes": sorted(self.nodes), - "edges": sorted(self.edges), - "name": self.name} + return { + "nodes": sorted(self.nodes), + "edges": sorted(self.edges), + "name": self.name, + } @classmethod def _from_dict(cls, d: dict) -> Self: - return cls(nodes=frozenset(d['nodes']), - edges=frozenset(d['edges']), - name=d.get('name')) + return cls(nodes=frozenset(d["nodes"]), edges=frozenset(d["edges"]), name=d.get("name")) @classmethod def _defaults(cls): @@ -126,7 +121,7 @@ def from_graphml(cls, str) -> Self: def _from_nx_graph(cls, nx_graph) -> Self: """Create an alchemical network from a networkx representation.""" chemical_systems = [n for n in nx_graph.nodes()] - transformations = [e[2]['object'] for e in nx_graph.edges(data=True)] + transformations = [e[2]["object"] for e in nx_graph.edges(data=True)] return cls(nodes=chemical_systems, edges=transformations) def connected_subgraphs(self) -> Generator[Self, None, None]: @@ -135,4 +130,4 @@ def connected_subgraphs(self) -> Generator[Self, None, None]: for node_group in node_groups: nx_subgraph = self.graph.subgraph(node_group) alc_subgraph = self._from_nx_graph(nx_subgraph) - yield(alc_subgraph) + yield (alc_subgraph) diff --git a/gufe/protocols/__init__.py b/gufe/protocols/__init__.py index 9b13eff6..fca9b580 100644 --- a/gufe/protocols/__init__.py +++ b/gufe/protocols/__init__.py @@ -1,4 +1,5 @@ """Defining processes for performing estimates of free energy differences""" + from .protocol import Protocol, ProtocolResult from .protocoldag import ProtocolDAG, ProtocolDAGResult, execute_DAG -from .protocolunit import ProtocolUnit, ProtocolUnitResult, ProtocolUnitFailure, Context +from .protocolunit import Context, ProtocolUnit, ProtocolUnitFailure, ProtocolUnitResult diff --git a/gufe/protocols/protocol.py b/gufe/protocols/protocol.py index bfda10c8..d086a2be 100644 --- a/gufe/protocols/protocol.py +++ b/gufe/protocols/protocol.py @@ -6,15 +6,16 @@ """ import abc -from typing import Optional, Iterable, Any, Union -from openff.units import Quantity import warnings +from collections.abc import Iterable, Sized +from typing import Any, Optional, Union + +from openff.units import Quantity -from ..settings import Settings, SettingsBaseModel -from ..tokenization import GufeTokenizable, GufeKey from ..chemicalsystem import ChemicalSystem from ..mapping import ComponentMapping - +from ..settings import Settings, SettingsBaseModel +from ..tokenization import GufeKey, GufeTokenizable from .protocoldag import ProtocolDAG, ProtocolDAGResult from .protocolunit import ProtocolUnit @@ -32,19 +33,33 @@ class ProtocolResult(GufeTokenizable): - `get_uncertainty` """ - def __init__(self, **data): + def __init__(self, n_protocol_dag_results: int = 0, **data): self._data = data + if not n_protocol_dag_results >= 0: + raise ValueError("`n_protocol_dag_results` must be an integer greater than or equal to zero") + + self._n_protocol_dag_results = n_protocol_dag_results + @classmethod def _defaults(cls): return {} def _to_dict(self): - return {'data': self.data} + return {"n_protocol_dag_results": self.n_protocol_dag_results, "data": self.data} @classmethod def _from_dict(cls, dct: dict): - return cls(**dct['data']) + # TODO: remove in gufe 2.0 + try: + n_protocol_dag_results = dct["n_protocol_dag_results"] + except KeyError: + n_protocol_dag_results = 0 + return cls(n_protocol_dag_results=n_protocol_dag_results, **dct["data"]) + + @property + def n_protocol_dag_results(self) -> int: + return self._n_protocol_dag_results @property def data(self) -> dict[str, Any]: @@ -57,12 +72,10 @@ def data(self) -> dict[str, Any]: return self._data @abc.abstractmethod - def get_estimate(self) -> Quantity: - ... + def get_estimate(self) -> Quantity: ... @abc.abstractmethod - def get_uncertainty(self) -> Quantity: - ... + def get_uncertainty(self) -> Quantity: ... class Protocol(GufeTokenizable): @@ -80,6 +93,7 @@ class Protocol(GufeTokenizable): - `_gather` - `_default_settings` """ + _settings: Settings result_cls: type[ProtocolResult] """Corresponding `ProtocolResult` subclass.""" @@ -109,7 +123,7 @@ def _defaults(cls): return {} def _to_dict(self): - return {'settings': self.settings} + return {"settings": self.settings} @classmethod def _from_dict(cls, dct: dict): @@ -180,9 +194,9 @@ def create( mapping: Optional[Union[ComponentMapping, list[ComponentMapping], dict[str, ComponentMapping]]], extends: Optional[ProtocolDAGResult] = None, name: Optional[str] = None, - transformation_key: Optional[GufeKey] = None + transformation_key: Optional[GufeKey] = None, ) -> ProtocolDAG: - """Prepare a `ProtocolDAG` with all information required for execution. + r"""Prepare a `ProtocolDAG` with all information required for execution. A :class:`.ProtocolDAG` is composed of :class:`.ProtocolUnit` \s, with dependencies established between them. These form a directed, acyclic @@ -221,9 +235,10 @@ def create( A directed, acyclic graph that can be executed by a `Scheduler`. """ if isinstance(mapping, dict): - warnings.warn(("mapping input as a dict is deprecated, " - "instead use either a single Mapping or list"), - DeprecationWarning) + warnings.warn( + ("mapping input as a dict is deprecated, " "instead use either a single Mapping or list"), + DeprecationWarning, + ) mapping = list(mapping.values()) return ProtocolDAG( @@ -235,12 +250,10 @@ def create( extends=extends, ), transformation_key=transformation_key, - extends_key=extends.key if extends is not None else None + extends_key=extends.key if extends is not None else None, ) - def gather( - self, protocol_dag_results: Iterable[ProtocolDAGResult] - ) -> ProtocolResult: + def gather(self, protocol_dag_results: Iterable[ProtocolDAGResult]) -> ProtocolResult: """Gather multiple ProtocolDAGResults into a single ProtocolResult. Parameters @@ -254,12 +267,16 @@ def gather( ProtocolResult Aggregated results from many `ProtocolDAGResult`s from a given `Protocol`. """ - return self.result_cls(**self._gather(protocol_dag_results)) + # Iterable does not implement __len__ and makes no guarantees that + # protocol_dag_results is finite, checking both in method signature + # doesn't appear possible, explicitly check for __len__ through the + # Sized type + if not isinstance(protocol_dag_results, Sized): + raise ValueError("`protocol_dag_results` must implement `__len__`") + return self.result_cls(n_protocol_dag_results=len(protocol_dag_results), **self._gather(protocol_dag_results)) @abc.abstractmethod - def _gather( - self, protocol_dag_results: Iterable[ProtocolDAGResult] - ) -> dict[str, Any]: + def _gather(self, protocol_dag_results: Iterable[ProtocolDAGResult]) -> dict[str, Any]: """Method to override in custom Protocol subclasses. This method should take any number of ``ProtocolDAGResult``s produced diff --git a/gufe/protocols/protocoldag.py b/gufe/protocols/protocoldag.py index d6958b76..74e7afbd 100644 --- a/gufe/protocols/protocoldag.py +++ b/gufe/protocols/protocoldag.py @@ -2,20 +2,19 @@ # For details, see https://github.com/OpenFreeEnergy/gufe import abc -from copy import copy -from collections import defaultdict import os -from typing import Iterable, Optional, Union, Any +import shutil +from collections import defaultdict +from collections.abc import Iterable +from copy import copy from os import PathLike from pathlib import Path -import shutil +from typing import Any, Optional, Union import networkx as nx -from ..tokenization import GufeTokenizable, GufeKey -from .protocolunit import ( - ProtocolUnit, ProtocolUnitResult, ProtocolUnitFailure, Context -) +from ..tokenization import GufeKey, GufeTokenizable +from .protocolunit import Context, ProtocolUnit, ProtocolUnitFailure, ProtocolUnitResult class DAGMixin: @@ -32,7 +31,7 @@ class DAGMixin: ## key of the ProtocolDAG this DAG extends _extends_key: Optional[GufeKey] - @staticmethod + @staticmethod def _build_graph(nodes): """Build dependency DAG of ProtocolUnits with input keys stored on edges""" G = nx.DiGraph() @@ -46,9 +45,7 @@ def _build_graph(nodes): @staticmethod def _iterate_dag_order(graph): - return reversed( - list(nx.lexicographical_topological_sort(graph, key=lambda pu: pu.key)) - ) + return reversed(list(nx.lexicographical_topological_sort(graph, key=lambda pu: pu.key))) @property def name(self) -> Optional[str]: @@ -104,13 +101,13 @@ class ProtocolDAGResult(GufeTokenizable, DAGMixin): There may be many of these for a given `Transformation`. Data elements from these objects are combined by `Protocol.gather` into a `ProtocolResult`. """ + _protocol_unit_results: list[ProtocolUnitResult] _unit_result_mapping: dict[ProtocolUnit, list[ProtocolUnitResult]] _result_unit_mapping: dict[ProtocolUnitResult, ProtocolUnit] - def __init__( - self, + self, *, protocol_units: list[ProtocolUnit], protocol_unit_results: list[ProtocolUnitResult], @@ -147,11 +144,13 @@ def _defaults(cls): return {} def _to_dict(self): - return {'name': self.name, - 'protocol_units': self._protocol_units, - 'protocol_unit_results': self._protocol_unit_results, - 'transformation_key': self._transformation_key, - 'extends_key': self._extends_key} + return { + "name": self.name, + "protocol_units": self._protocol_units, + "protocol_unit_results": self._protocol_unit_results, + "transformation_key": self._transformation_key, + "extends_key": self._extends_key, + } @classmethod def _from_dict(cls, dct: dict): @@ -190,7 +189,7 @@ def protocol_unit_failures(self) -> list[ProtocolUnitFailure]: # mypy can't figure out the types here, .ok() will ensure a certain type # https://mypy.readthedocs.io/en/stable/common_issues.html?highlight=cast#complex-type-tests return [r for r in self.protocol_unit_results if not r.ok()] # type: ignore - + @property def protocol_unit_successes(self) -> list[ProtocolUnitResult]: """A list of only successful `ProtocolUnit` results. @@ -252,8 +251,7 @@ def result_to_unit(self, protocol_unit_result: ProtocolUnitResult) -> ProtocolUn def ok(self) -> bool: # ensure that for every protocol unit, there is an OK result object - return all(any(pur.ok() for pur in self._unit_result_mapping[pu]) - for pu in self._protocol_units) + return all(any(pur.ok() for pur in self._unit_result_mapping[pu]) for pu in self._protocol_units) @property def terminal_protocol_unit_results(self) -> list[ProtocolUnitResult]: @@ -265,8 +263,7 @@ def terminal_protocol_unit_results(self) -> list[ProtocolUnitResult]: All ProtocolUnitResults which do not have a ProtocolUnitResult that follows on (depends) on them. """ - return [u for u in self._protocol_unit_results - if not nx.ancestors(self._result_graph, u)] + return [u for u in self._protocol_unit_results if not nx.ancestors(self._result_graph, u)] class ProtocolDAG(GufeTokenizable, DAGMixin): @@ -336,24 +333,28 @@ def _defaults(cls): return {} def _to_dict(self): - return {'name': self.name, - 'protocol_units': self.protocol_units, - 'transformation_key': self._transformation_key, - 'extends_key': self._extends_key} + return { + "name": self.name, + "protocol_units": self.protocol_units, + "transformation_key": self._transformation_key, + "extends_key": self._extends_key, + } @classmethod def _from_dict(cls, dct: dict): return cls(**dct) -def execute_DAG(protocoldag: ProtocolDAG, *, - shared_basedir: Path, - scratch_basedir: Path, - keep_shared: bool = False, - keep_scratch: bool = False, - raise_error: bool = True, - n_retries: int = 0, - ) -> ProtocolDAGResult: +def execute_DAG( + protocoldag: ProtocolDAG, + *, + shared_basedir: Path, + scratch_basedir: Path, + keep_shared: bool = False, + keep_scratch: bool = False, + raise_error: bool = True, + n_retries: int = 0, +) -> ProtocolDAGResult: """ Locally execute a full :class:`ProtocolDAG` in serial and in-process. @@ -400,21 +401,17 @@ def execute_DAG(protocoldag: ProtocolDAG, *, attempt = 0 while attempt <= n_retries: - shared = shared_basedir / f'shared_{str(unit.key)}_attempt_{attempt}' + shared = shared_basedir / f"shared_{str(unit.key)}_attempt_{attempt}" shared_paths.append(shared) shared.mkdir() - scratch = scratch_basedir / f'scratch_{str(unit.key)}_attempt_{attempt}' + scratch = scratch_basedir / f"scratch_{str(unit.key)}_attempt_{attempt}" scratch.mkdir() - context = Context(shared=shared, - scratch=scratch) + context = Context(shared=shared, scratch=scratch) # execute - result = unit.execute( - context=context, - raise_error=raise_error, - **inputs) + result = unit.execute(context=context, raise_error=raise_error, **inputs) all_results.append(result) if not keep_scratch: @@ -434,16 +431,18 @@ def execute_DAG(protocoldag: ProtocolDAG, *, shutil.rmtree(shared_path) return ProtocolDAGResult( - name=protocoldag.name, - protocol_units=protocoldag.protocol_units, - protocol_unit_results=all_results, - transformation_key=protocoldag.transformation_key, - extends_key=protocoldag.extends_key) + name=protocoldag.name, + protocol_units=protocoldag.protocol_units, + protocol_unit_results=all_results, + transformation_key=protocoldag.transformation_key, + extends_key=protocoldag.extends_key, + ) def _pu_to_pur( - inputs: Union[dict[str, Any], list[Any], ProtocolUnit], - mapping: dict[GufeKey, ProtocolUnitResult]): + inputs: Union[dict[str, Any], list[Any], ProtocolUnit], + mapping: dict[GufeKey, ProtocolUnitResult], +): """Convert each `ProtocolUnit` found within `inputs` to its corresponding `ProtocolUnitResult`. @@ -467,4 +466,3 @@ def _pu_to_pur( return mapping[inputs.key] else: return inputs - diff --git a/gufe/protocols/protocolunit.py b/gufe/protocols/protocolunit.py index 73728bdf..fba47adb 100644 --- a/gufe/protocols/protocolunit.py +++ b/gufe/protocols/protocolunit.py @@ -8,20 +8,19 @@ from __future__ import annotations import abc -from dataclasses import dataclass import datetime import sys +import tempfile import traceback import uuid +from collections.abc import Iterable +from copy import copy +from dataclasses import dataclass from os import PathLike from pathlib import Path -from copy import copy -from typing import Iterable, Tuple, List, Dict, Any, Optional, Union -import tempfile +from typing import Any, Dict, List, Optional, Tuple, Union -from ..tokenization import ( - GufeTokenizable, GufeKey, TOKENIZABLE_REGISTRY -) +from ..tokenization import TOKENIZABLE_REGISTRY, GufeKey, GufeTokenizable @dataclass @@ -30,6 +29,7 @@ class Context: `ProtocolUnit._execute`. """ + scratch: PathLike shared: PathLike @@ -55,14 +55,16 @@ class ProtocolUnitResult(GufeTokenizable): Successful result of a single :class:`ProtocolUnit` execution. """ - def __init__(self, *, - name: Optional[str] = None, - source_key: GufeKey, - inputs: Dict[str, Any], - outputs: Dict[str, Any], - start_time: Optional[datetime.datetime] = None, - end_time: Optional[datetime.datetime] = None, - ): + def __init__( + self, + *, + name: str | None = None, + source_key: GufeKey, + inputs: dict[str, Any], + outputs: dict[str, Any], + start_time: datetime.datetime | None = None, + end_time: datetime.datetime | None = None, + ): """ Parameters ---------- @@ -102,16 +104,18 @@ def _defaults(cls): return {} def _to_dict(self): - return {'name': self.name, - '_key': self.key, - 'source_key': self.source_key, - 'inputs': self.inputs, - 'outputs': self.outputs, - 'start_time': self.start_time, - 'end_time': self.end_time} + return { + "name": self.name, + "_key": self.key, + "source_key": self.source_key, + "inputs": self.inputs, + "outputs": self.outputs, + "start_time": self.start_time, + "end_time": self.end_time, + } @classmethod - def _from_dict(cls, dct: Dict): + def _from_dict(cls, dct: dict): key = dct.pop("_key") obj = cls(**dct) obj._set_key(key) @@ -138,19 +142,19 @@ def dependencies(self) -> list[ProtocolUnitResult]: """All results that this result was dependent on""" if self._dependencies is None: self._dependencies = _list_dependencies(self._inputs, ProtocolUnitResult) - return self._dependencies # type: ignore + return self._dependencies # type: ignore @staticmethod def ok() -> bool: return True @property - def start_time(self) -> Optional[datetime.datetime]: + def start_time(self) -> datetime.datetime | None: """The time execution of this Unit began""" return self._start_time @property - def end_time(self) -> Optional[datetime.datetime]: + def end_time(self) -> datetime.datetime | None: """The time at which execution of this Unit finished""" return self._end_time @@ -159,18 +163,18 @@ class ProtocolUnitFailure(ProtocolUnitResult): """Failed result of a single :class:`ProtocolUnit` execution.""" def __init__( - self, - *, - name=None, - source_key, - inputs, - outputs, - _key=None, - exception, - traceback, - start_time: Optional[datetime.datetime] = None, - end_time: Optional[datetime.datetime] = None, - ): + self, + *, + name=None, + source_key, + inputs, + outputs, + _key=None, + exception, + traceback, + start_time: datetime.datetime | None = None, + end_time: datetime.datetime | None = None, + ): """ Parameters ---------- @@ -194,17 +198,22 @@ def __init__( """ self._exception = exception self._traceback = traceback - super().__init__(name=name, source_key=source_key, inputs=inputs, outputs=outputs, - start_time=start_time, end_time=end_time) + super().__init__( + name=name, + source_key=source_key, + inputs=inputs, + outputs=outputs, + start_time=start_time, + end_time=end_time, + ) def _to_dict(self): dct = super()._to_dict() - dct.update({'exception': self.exception, - 'traceback': self.traceback}) + dct.update({"exception": self.exception, "traceback": self.traceback}) return dct @property - def exception(self) -> Tuple[str, Tuple[Any, ...]]: + def exception(self) -> tuple[str, tuple[Any, ...]]: return self._exception @property @@ -218,21 +227,17 @@ def ok() -> bool: class ProtocolUnit(GufeTokenizable): """A unit of work within a ProtocolDAG.""" - _dependencies: Optional[list[ProtocolUnit]] - def __init__( - self, - *, - name: Optional[str] = None, - **inputs - ): + _dependencies: list[ProtocolUnit] | None + + def __init__(self, *, name: str | None = None, **inputs): """Create an instance of a ProtocolUnit. Parameters ---------- name : str - Custom name to give this - **inputs + Custom name to give this + **inputs Keyword arguments, which can include other `ProtocolUnit`s on which this `ProtocolUnit` is dependent. Should be either `gufe` objects or JSON-serializables. @@ -257,28 +262,25 @@ def _defaults(cls): return {} def _to_dict(self): - return {'inputs': self.inputs, - 'name': self.name, - '_key': self.key} + return {"inputs": self.inputs, "name": self.name, "_key": self.key} @classmethod - def _from_dict(cls, dct: Dict): - _key = dct.pop('_key') + def _from_dict(cls, dct: dict): + _key = dct.pop("_key") - obj = cls(name=dct['name'], - **dct['inputs']) + obj = cls(name=dct["name"], **dct["inputs"]) obj._set_key(_key) return obj @property - def name(self) -> Optional[str]: + def name(self) -> str | None: """ Optional name for the `ProtocolUnit`. """ return self._name @property - def inputs(self) -> Dict[str, Any]: + def inputs(self) -> dict[str, Any]: """ Inputs to the `ProtocolUnit`. @@ -291,12 +293,11 @@ def dependencies(self) -> list[ProtocolUnit]: """All units that this unit is dependent on (parents)""" if self._dependencies is None: self._dependencies = _list_dependencies(self._inputs, ProtocolUnit) - return self._dependencies # type: ignore + return self._dependencies # type: ignore - def execute(self, *, - context: Context, - raise_error: bool = False, - **inputs) -> Union[ProtocolUnitResult, ProtocolUnitFailure]: + def execute( + self, *, context: Context, raise_error: bool = False, **inputs + ) -> ProtocolUnitResult | ProtocolUnitFailure: """Given `ProtocolUnitResult` s from dependencies, execute this `ProtocolUnit`. Parameters @@ -313,14 +314,18 @@ def execute(self, *, objects this unit is dependent on. """ - result: Union[ProtocolUnitResult, ProtocolUnitFailure] + result: ProtocolUnitResult | ProtocolUnitFailure start = datetime.datetime.now() try: outputs = self._execute(context, **inputs) result = ProtocolUnitResult( - name=self.name, source_key=self.key, inputs=inputs, outputs=outputs, - start_time=start, end_time=datetime.datetime.now(), + name=self.name, + source_key=self.key, + inputs=inputs, + outputs=outputs, + start_time=start, + end_time=datetime.datetime.now(), ) except KeyboardInterrupt: @@ -345,7 +350,7 @@ def execute(self, *, @staticmethod @abc.abstractmethod - def _execute(ctx: Context, **inputs) -> Dict[str, Any]: + def _execute(ctx: Context, **inputs) -> dict[str, Any]: """Method to override in custom `ProtocolUnit` subclasses. A `Context` is always given as its first argument, which provides execution @@ -362,15 +367,15 @@ def _execute(ctx: Context, **inputs) -> Dict[str, Any]: where instantiation with the subclass `MyProtocolUnit` would look like: - >>> unit = MyProtocolUnit(settings=settings_dict, + >>> unit = MyProtocolUnit(settings=settings_dict, initialization=another_protocolunit, some_arg=7, another_arg="five") - Inside of `_execute` above: + Inside of `_execute` above: - `settings`, and `some_arg`, would have their values set as given - `initialization` would get the `ProtocolUnitResult` that comes from - `another_protocolunit`'s own execution. + `another_protocolunit`'s own execution. - `another_arg` would be accessible via `inputs['another_arg']` This allows protocol developers to define how `ProtocolUnit`s are diff --git a/gufe/settings/__init__.py b/gufe/settings/__init__.py index 01cf0e78..be4c3fde 100644 --- a/gufe/settings/__init__.py +++ b/gufe/settings/__init__.py @@ -1,10 +1,4 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/gufe """General models for defining the parameters that protocols use""" -from .models import ( - Settings, - ThermoSettings, - BaseForceFieldSettings, - OpenMMSystemGeneratorFFSettings, - SettingsBaseModel, -) +from .models import BaseForceFieldSettings, OpenMMSystemGeneratorFFSettings, Settings, SettingsBaseModel, ThermoSettings diff --git a/gufe/settings/models.py b/gufe/settings/models.py index 08d3973c..734fcb56 100644 --- a/gufe/settings/models.py +++ b/gufe/settings/models.py @@ -5,21 +5,15 @@ """ import abc +import pprint from typing import Optional, Union from openff.models.models import DefaultModel from openff.models.types import FloatQuantity from openff.units import unit -import pprint try: - from pydantic.v1 import ( - Extra, - Field, - PositiveFloat, - PrivateAttr, - validator, - ) + from pydantic.v1 import Extra, Field, PositiveFloat, PrivateAttr, validator except ImportError: from pydantic import ( Extra, @@ -28,17 +22,20 @@ PrivateAttr, validator, ) + import pydantic class SettingsBaseModel(DefaultModel): """Settings and modifications we want for all settings classes.""" + _is_frozen: bool = PrivateAttr(default_factory=lambda: False) class Config: """ :noindex: """ + extra = pydantic.Extra.forbid arbitrary_types_allowed = False smart_union = True @@ -56,8 +53,7 @@ def frozen_copy(self): def freeze_model(model): submodels = ( - mod for field in model.__fields__ - if isinstance(mod := getattr(model, field), SettingsBaseModel) + mod for field in model.__fields__ if isinstance(mod := getattr(model, field), SettingsBaseModel) ) for mod in submodels: freeze_model(mod) @@ -78,8 +74,7 @@ def unfrozen_copy(self): def unfreeze_model(model): submodels = ( - mod for field in model.__fields__ - if isinstance(mod := getattr(model, field), SettingsBaseModel) + mod for field in model.__fields__ if isinstance(mod := getattr(model, field), SettingsBaseModel) ) for mod in submodels: unfreeze_model(mod) @@ -100,7 +95,8 @@ def __setattr__(self, name, value): raise AttributeError( f"Cannot set '{name}': Settings are immutable once attached" " to a Protocol and cannot be modified. Modify Settings " - "*before* creating the Protocol.") + "*before* creating the Protocol." + ) return super().__setattr__(name, value) @@ -112,23 +108,22 @@ class ThermoSettings(SettingsBaseModel): possible. """ - temperature: FloatQuantity["kelvin"] = Field( - None, description="Simulation temperature, default units kelvin" - ) + temperature: FloatQuantity["kelvin"] = Field(None, description="Simulation temperature, default units kelvin") pressure: FloatQuantity["standard_atmosphere"] = Field( None, description="Simulation pressure, default units standard atmosphere (atm)" ) ph: Union[PositiveFloat, None] = Field(None, description="Simulation pH") - redox_potential: Optional[float] = Field( - None, description="Simulation redox potential" - ) + redox_potential: Optional[float] = Field(None, description="Simulation redox potential") class BaseForceFieldSettings(SettingsBaseModel, abc.ABC): """Base class for ForceFieldSettings objects""" + class Config: """:noindex:""" + pass + ... @@ -145,11 +140,13 @@ class OpenMMSystemGeneratorFFSettings(BaseForceFieldSettings): .. _`OpenMMForceField SystemGenerator documentation`: https://github.com/openmm/openmmforcefields#automating-force-field-management-with-systemgenerator """ + class Config: """:noindex:""" + pass - - constraints: Optional[str] = 'hbonds' + + constraints: Optional[str] = "hbonds" """Constraints to be applied to system. One of 'hbonds', 'allbonds', 'hangles' or None, default 'hbonds'""" rigid_water: bool = True @@ -169,39 +166,37 @@ class Config: small_molecule_forcefield: str = "openff-2.1.1" # other default ideas 'openff-2.0.0', 'gaff-2.11', 'espaloma-0.2.0' """Name of the force field to be used for :class:`SmallMoleculeComponent` """ - nonbonded_method = 'PME' + nonbonded_method = "PME" """ Method for treating nonbonded interactions, currently only PME and NoCutoff are allowed. Default PME. """ - nonbonded_cutoff: FloatQuantity['nanometer'] = 1.0 * unit.nanometer + nonbonded_cutoff: FloatQuantity["nanometer"] = 1.0 * unit.nanometer """ Cutoff value for short range nonbonded interactions. Default 1.0 * unit.nanometer. """ - @validator('nonbonded_method') + @validator("nonbonded_method") def allowed_nonbonded(cls, v): - if v.lower() not in ['pme', 'nocutoff']: - errmsg = ( - "Only PME and NoCutoff are allowed nonbonded_methods") + if v.lower() not in ["pme", "nocutoff"]: + errmsg = "Only PME and NoCutoff are allowed nonbonded_methods" raise ValueError(errmsg) return v - @validator('nonbonded_cutoff') + @validator("nonbonded_cutoff") def is_positive_distance(cls, v): # these are time units, not simulation steps if not v.is_compatible_with(unit.nanometer): - raise ValueError("nonbonded_cutoff must be in distance units " - "(i.e. nanometers)") + raise ValueError("nonbonded_cutoff must be in distance units " "(i.e. nanometers)") if v < 0: errmsg = "nonbonded_cutoff must be a positive value" raise ValueError(errmsg) return v - @validator('constraints') + @validator("constraints") def constraint_check(cls, v): - allowed = {'hbonds', 'hangles', 'allbonds'} + allowed = {"hbonds", "hangles", "allbonds"} if not (v is None or v.lower() in allowed): raise ValueError(f"Bad constraints value, use one of {allowed}") @@ -217,6 +212,7 @@ class Settings(SettingsBaseModel): Protocols can subclass this to extend this to cater for their additional settings. """ + forcefield_settings: BaseForceFieldSettings thermo_settings: ThermoSettings diff --git a/gufe/storage/__init__.py b/gufe/storage/__init__.py index 27bcd5ae..853c55a2 100644 --- a/gufe/storage/__init__.py +++ b/gufe/storage/__init__.py @@ -1 +1 @@ -"""How to store objects across simulation campaigns""" \ No newline at end of file +"""How to store objects across simulation campaigns""" diff --git a/gufe/storage/errors.py b/gufe/storage/errors.py index ca3b9889..079e452a 100644 --- a/gufe/storage/errors.py +++ b/gufe/storage/errors.py @@ -1,8 +1,10 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/gufe + class ExternalResourceError(Exception): """Base class for errors due to problems with external resources""" + # TODO: is it necessary to have a base class here? Would you ever have # one catch that handles both subclass errors? diff --git a/gufe/storage/externalresource/base.py b/gufe/storage/externalresource/base.py index f72a3de6..ae32d75a 100644 --- a/gufe/storage/externalresource/base.py +++ b/gufe/storage/externalresource/base.py @@ -1,18 +1,16 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/gufe import abc +import dataclasses +import glob import hashlib -import pathlib -import shutil import io import os -import glob -from typing import Union, Tuple, ContextManager -import dataclasses +import pathlib +import shutil +from typing import ContextManager, Tuple, Union -from ..errors import ( - MissingExternalResourceError, ChangedExternalResourceError -) +from ..errors import ChangedExternalResourceError, MissingExternalResourceError @dataclasses.dataclass @@ -31,6 +29,7 @@ class _ForceContext: Filelike objects can often be used with explicit open/close. This requires the returned byteslike to be consumed as a context manager. """ + def __init__(self, context): self._context = context diff --git a/gufe/storage/externalresource/filestorage.py b/gufe/storage/externalresource/filestorage.py index 5330f88b..e2b5a54b 100644 --- a/gufe/storage/externalresource/filestorage.py +++ b/gufe/storage/externalresource/filestorage.py @@ -1,16 +1,13 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/gufe +import os import pathlib import shutil -import os -from typing import Union, Tuple, ContextManager +from typing import ContextManager, Tuple, Union +from ..errors import ChangedExternalResourceError, MissingExternalResourceError from .base import ExternalStorage -from ..errors import ( - MissingExternalResourceError, ChangedExternalResourceError -) - # TODO: this should use pydantic to check init inputs class FileStorage(ExternalStorage): @@ -21,10 +18,7 @@ def _exists(self, location): return self._as_path(location).exists() def __eq__(self, other): - return ( - isinstance(other, FileStorage) - and self.root_dir == other.root_dir - ) + return isinstance(other, FileStorage) and self.root_dir == other.root_dir def _store_bytes(self, location, byte_data): path = self._as_path(location) @@ -32,7 +26,7 @@ def _store_bytes(self, location, byte_data): filename = path.name # TODO: add some stuff here to catch permissions-based errors directory.mkdir(parents=True, exist_ok=True) - with open(path, mode='wb') as f: + with open(path, mode="wb") as f: f.write(byte_data) def _store_path(self, location, path): @@ -55,9 +49,7 @@ def _delete(self, location): if self.exists(location): path.unlink() else: - raise MissingExternalResourceError( - f"Unable to delete '{str(path)}': File does not exist" - ) + raise MissingExternalResourceError(f"Unable to delete '{str(path)}': File does not exist") def _as_path(self, location): return self.root_dir / pathlib.Path(location) @@ -73,6 +65,6 @@ def _get_filename(self, location): def _load_stream(self, location): try: - return open(self._as_path(location), 'rb') + return open(self._as_path(location), "rb") except OSError as e: raise MissingExternalResourceError(str(e)) diff --git a/gufe/storage/externalresource/memorystorage.py b/gufe/storage/externalresource/memorystorage.py index 40c05087..e8ee07d3 100644 --- a/gufe/storage/externalresource/memorystorage.py +++ b/gufe/storage/externalresource/memorystorage.py @@ -1,17 +1,15 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/gufe import io -from typing import Union, Tuple, ContextManager +from typing import ContextManager, Tuple, Union +from ..errors import ChangedExternalResourceError, MissingExternalResourceError from .base import ExternalStorage -from ..errors import ( - MissingExternalResourceError, ChangedExternalResourceError -) - class MemoryStorage(ExternalStorage): """Not for production use, but potentially useful in testing""" + def __init__(self): self._data = {} @@ -22,9 +20,7 @@ def _delete(self, location): try: del self._data[location] except KeyError: - raise MissingExternalResourceError( - f"Unable to delete '{location}': key does not exist" - ) + raise MissingExternalResourceError(f"Unable to delete '{location}': key does not exist") def __eq__(self, other): return self is other @@ -34,7 +30,7 @@ def _store_bytes(self, location, byte_data): return location, self.get_metadata(location) def _store_path(self, location, path): - with open(path, 'rb') as f: + with open(path, "rb") as f: byte_data = f.read() return self._store_bytes(location, byte_data) diff --git a/gufe/tests/conftest.py b/gufe/tests/conftest.py index 3080f469..63dd30a0 100644 --- a/gufe/tests/conftest.py +++ b/gufe/tests/conftest.py @@ -1,15 +1,17 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/gufe +import functools import importlib.resources +import io import urllib.request from urllib.error import URLError -import io -import functools + import pytest +from openff.units import unit from rdkit import Chem from rdkit.Chem import AllChem -from openff.units import unit + import gufe from gufe.tests.test_protocol import DummyProtocol @@ -22,7 +24,7 @@ class URLFileLike: - def __init__(self, url, encoding='utf-8'): + def __init__(self, url, encoding="utf-8"): self.url = url self.encoding = encoding self.data = None @@ -39,21 +41,21 @@ def __call__(self): def get_test_filename(filename): - with importlib.resources.path('gufe.tests.data', filename) as file: + with importlib.resources.path("gufe.tests.data", filename) as file: return str(file) _benchmark_pdb_names = [ - "cmet_protein", - "hif2a_protein", - "mcl1_protein", - "p38_protein", - "ptp1b_protein", - "syk_protein", - "thrombin_protein", - "tnsk2_protein", - "tyk2_protein", - ] + "cmet_protein", + "hif2a_protein", + "mcl1_protein", + "p38_protein", + "ptp1b_protein", + "syk_protein", + "thrombin_protein", + "tnsk2_protein", + "tyk2_protein", +] _pl_benchmark_url_pattern = ( @@ -62,14 +64,10 @@ def get_test_filename(filename): PDB_BENCHMARK_LOADERS = { - name: URLFileLike(url=_pl_benchmark_url_pattern.format(name=name)) - for name in _benchmark_pdb_names + name: URLFileLike(url=_pl_benchmark_url_pattern.format(name=name)) for name in _benchmark_pdb_names } -PDB_FILE_LOADERS = { - name: lambda: get_test_filename(name) - for name in ["181l.pdb"] -} +PDB_FILE_LOADERS = {name: lambda: get_test_filename(name) for name in ["181l.pdb"]} ALL_PDB_LOADERS = dict(**PDB_BENCHMARK_LOADERS, **PDB_FILE_LOADERS) @@ -82,78 +80,72 @@ def ethane_sdf(): @pytest.fixture def toluene_mol2_path(): - with importlib.resources.path('gufe.tests.data', 'toluene.mol2') as f: + with importlib.resources.path("gufe.tests.data", "toluene.mol2") as f: yield str(f) @pytest.fixture def multi_molecule_sdf(): - fn = 'multi_molecule.sdf' - with importlib.resources.path('gufe.tests.data', fn) as f: + fn = "multi_molecule.sdf" + with importlib.resources.path("gufe.tests.data", fn) as f: yield str(f) @pytest.fixture def PDB_181L_path(): - with importlib.resources.path('gufe.tests.data', '181l.pdb') as f: + with importlib.resources.path("gufe.tests.data", "181l.pdb") as f: yield str(f) @pytest.fixture def PDB_181L_OpenMMClean_path(): - with importlib.resources.path('gufe.tests.data', - '181l_openmmClean.pdb') as f: + with importlib.resources.path("gufe.tests.data", "181l_openmmClean.pdb") as f: yield str(f) @pytest.fixture def offxml_settings_path(): - with importlib.resources.path('gufe.tests.data', 'offxml_settings.json') as f: + with importlib.resources.path("gufe.tests.data", "offxml_settings.json") as f: yield str(f) @pytest.fixture def all_settings_path(): - with importlib.resources.path('gufe.tests.data', 'all_settings.json') as f: + with importlib.resources.path("gufe.tests.data", "all_settings.json") as f: yield str(f) @pytest.fixture def PDB_thrombin_path(): - with importlib.resources.path('gufe.tests.data', - 'thrombin_protein.pdb') as f: + with importlib.resources.path("gufe.tests.data", "thrombin_protein.pdb") as f: yield str(f) @pytest.fixture def PDBx_181L_path(): - with importlib.resources.path('gufe.tests.data', - '181l.cif') as f: + with importlib.resources.path("gufe.tests.data", "181l.cif") as f: yield str(f) @pytest.fixture def PDBx_181L_openMMClean_path(): - with importlib.resources.path('gufe.tests.data', - '181l_openmmClean.cif') as f: + with importlib.resources.path("gufe.tests.data", "181l_openmmClean.cif") as f: yield str(f) -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def benzene_modifications(): - with importlib.resources.path('gufe.tests.data', - 'benzene_modifications.sdf') as f: + with importlib.resources.path("gufe.tests.data", "benzene_modifications.sdf") as f: supp = Chem.SDMolSupplier(str(f), removeHs=False) mols = list(supp) - return {m.GetProp('_Name'): m for m in mols} + return {m.GetProp("_Name"): m for m in mols} -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def benzene_transforms(benzene_modifications): - return {k: gufe.SmallMoleculeComponent(v) - for k, v in benzene_modifications.items()} + return {k: gufe.SmallMoleculeComponent(v) for k, v in benzene_modifications.items()} @pytest.fixture @@ -168,7 +160,7 @@ def toluene(benzene_modifications): @pytest.fixture def phenol(benzene_modifications): - return gufe.SmallMoleculeComponent(benzene_modifications['phenol']) + return gufe.SmallMoleculeComponent(benzene_modifications["phenol"]) @pytest.fixture @@ -253,9 +245,10 @@ def absolute_transformation(solvated_ligand, solvated_complex): def complex_equilibrium(solvated_complex): return gufe.NonTransformation( solvated_complex, - protocol=DummyProtocol(settings=DummyProtocol.default_settings()) + protocol=DummyProtocol(settings=DummyProtocol.default_settings()), ) + @pytest.fixture def benzene_variants_star_map_transformations( benzene, @@ -269,26 +262,20 @@ def benzene_variants_star_map_transformations( solv_comp, ): - variants = [toluene, phenol, benzonitrile, anisole, benzaldehyde, - styrene] + variants = [toluene, phenol, benzonitrile, anisole, benzaldehyde, styrene] # define the solvent chemical systems and transformations between # benzene and the others solvated_ligands = {} solvated_ligand_transformations = {} - solvated_ligands["benzene"] = gufe.ChemicalSystem( - {"solvent": solv_comp, "ligand": benzene}, name="benzene-solvent" - ) + solvated_ligands["benzene"] = gufe.ChemicalSystem({"solvent": solv_comp, "ligand": benzene}, name="benzene-solvent") for ligand in variants: solvated_ligands[ligand.name] = gufe.ChemicalSystem( - {"solvent": solv_comp, "ligand": ligand}, - name=f"{ligand.name}-solvnet" + {"solvent": solv_comp, "ligand": ligand}, name=f"{ligand.name}-solvnet" ) - solvated_ligand_transformations[ - ("benzene", ligand.name) - ] = gufe.Transformation( + solvated_ligand_transformations[("benzene", ligand.name)] = gufe.Transformation( solvated_ligands["benzene"], solvated_ligands[ligand.name], protocol=DummyProtocol(settings=DummyProtocol.default_settings()), @@ -310,9 +297,7 @@ def benzene_variants_star_map_transformations( {"protein": prot_comp, "solvent": solv_comp, "ligand": ligand}, name=f"{ligand.name}-complex", ) - solvated_complex_transformations[ - ("benzene", ligand.name) - ] = gufe.Transformation( + solvated_complex_transformations[("benzene", ligand.name)] = gufe.Transformation( solvated_complexes["benzene"], solvated_complexes[ligand.name], protocol=DummyProtocol(settings=DummyProtocol.default_settings()), @@ -325,7 +310,8 @@ def benzene_variants_star_map_transformations( @pytest.fixture def benzene_variants_star_map(benzene_variants_star_map_transformations): solvated_ligand_transformations, solvated_complex_transformations = benzene_variants_star_map_transformations - return gufe.AlchemicalNetwork(solvated_ligand_transformations+solvated_complex_transformations) + return gufe.AlchemicalNetwork(solvated_ligand_transformations + solvated_complex_transformations) + @pytest.fixture def benzene_variants_ligand_star_map(benzene_variants_star_map_transformations): diff --git a/gufe/tests/data/__init__.py b/gufe/tests/data/__init__.py index 3bf5fbe3..d9960422 100644 --- a/gufe/tests/data/__init__.py +++ b/gufe/tests/data/__init__.py @@ -1,2 +1,2 @@ # This code is part of OpenFE and is licensed under the MIT license. -# For details, see https://github.com/OpenFreeEnergy/gufe \ No newline at end of file +# For details, see https://github.com/OpenFreeEnergy/gufe diff --git a/gufe/tests/data/all_settings.json b/gufe/tests/data/all_settings.json index 2218fba7..d50c1a79 100644 --- a/gufe/tests/data/all_settings.json +++ b/gufe/tests/data/all_settings.json @@ -22,4 +22,4 @@ "redox_potential": null }, "protocol_settings": null -} \ No newline at end of file +} diff --git a/gufe/tests/data/ligand_network.graphml b/gufe/tests/data/ligand_network.graphml index 31fe9e56..d331b2f6 100644 --- a/gufe/tests/data/ligand_network.graphml +++ b/gufe/tests/data/ligand_network.graphml @@ -4,13 +4,13 @@ - {":version:": 1, "__module__": "gufe.components.smallmoleculecomponent", "__qualname__": "SmallMoleculeComponent", "atoms": [[6, 0, 0, false, 0, 0, {}], [6, 0, 0, false, 0, 0, {}]], "bonds": [[0, 1, 1, 0, {}]], "conformer": ["\u0093NUMPY\u0001\u0000v\u0000{'descr': '<f8', 'fortran_order': False, 'shape': (2, 3), } \n\u0000\u0000\u0000\u0000\u0000\u0000\u00e8\u00bf\u0000\u0000\u0000\u0000\u0000\u0000\u0090<\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u00e8?\u0000\u0000\u0000\u0000\u0000\u0000\u0090\u00bc\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000", {}], "molprops": {"ofe-name": ""}} + {":version:": 1, "__module__": "gufe.components.smallmoleculecomponent", "__qualname__": "SmallMoleculeComponent", "atoms": [[6, 0, 0, false, 0, 0, {}, 4], [6, 0, 0, false, 0, 0, {}, 4]], "bonds": [[0, 1, 1, 0, {}]], "conformer": ["\u0093NUMPY\u0001\u0000v\u0000{'descr': '<f8', 'fortran_order': False, 'shape': (2, 3), } \n\u0000\u0000\u0000\u0000\u0000\u0000\u00e8\u00bf\u0000\u0000\u0000\u0000\u0000\u0000\u0090<\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u00e8?\u0000\u0000\u0000\u0000\u0000\u0000\u0090\u00bc\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000", {}], "molprops": {"ofe-name": ""}} - {":version:": 1, "__module__": "gufe.components.smallmoleculecomponent", "__qualname__": "SmallMoleculeComponent", "atoms": [[6, 0, 0, false, 0, 0, {}], [6, 0, 0, false, 0, 0, {}], [8, 0, 0, false, 0, 0, {}]], "bonds": [[0, 1, 1, 0, {}], [1, 2, 1, 0, {}]], "conformer": ["\u0093NUMPY\u0001\u0000v\u0000{'descr': '<f8', 'fortran_order': False, 'shape': (3, 3), } \n\u00809B.\u00dc\u00c8\u00f4\u00bf\u00f5\u00ff\u00ff\u00ff\u00ff\u00ff\u00cf\u00bf\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0001\u0000\u0000\u0000\u0000\u0000\u00e0?\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u00809B.\u00dc\u00c8\u00f4?\u0006\u0000\u0000\u0000\u0000\u0000\u00d0\u00bf\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000", {}], "molprops": {"ofe-name": ""}} + {":version:": 1, "__module__": "gufe.components.smallmoleculecomponent", "__qualname__": "SmallMoleculeComponent", "atoms": [[6, 0, 0, false, 0, 0, {}, 4], [6, 0, 0, false, 0, 0, {}, 4], [8, 0, 0, false, 0, 0, {}, 4]], "bonds": [[0, 1, 1, 0, {}], [1, 2, 1, 0, {}]], "conformer": ["\u0093NUMPY\u0001\u0000v\u0000{'descr': '<f8', 'fortran_order': False, 'shape': (3, 3), } \n\u00809B.\u00dc\u00c8\u00f4\u00bf\u00f5\u00ff\u00ff\u00ff\u00ff\u00ff\u00cf\u00bf\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0001\u0000\u0000\u0000\u0000\u0000\u00e0?\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u00809B.\u00dc\u00c8\u00f4?\u0006\u0000\u0000\u0000\u0000\u0000\u00d0\u00bf\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000", {}], "molprops": {"ofe-name": ""}} - {":version:": 1, "__module__": "gufe.components.smallmoleculecomponent", "__qualname__": "SmallMoleculeComponent", "atoms": [[6, 0, 0, false, 0, 0, {}], [8, 0, 0, false, 0, 0, {}]], "bonds": [[0, 1, 1, 0, {}]], "conformer": ["\u0093NUMPY\u0001\u0000v\u0000{'descr': '<f8', 'fortran_order': False, 'shape': (2, 3), } \n\u0000\u0000\u0000\u0000\u0000\u0000\u00e8\u00bf\u0000\u0000\u0000\u0000\u0000\u0000\u0090<\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u00e8?\u0000\u0000\u0000\u0000\u0000\u0000\u0090\u00bc\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000", {}], "molprops": {"ofe-name": ""}} + {":version:": 1, "__module__": "gufe.components.smallmoleculecomponent", "__qualname__": "SmallMoleculeComponent", "atoms": [[6, 0, 0, false, 0, 0, {}, 4], [8, 0, 0, false, 0, 0, {}, 4]], "bonds": [[0, 1, 1, 0, {}]], "conformer": ["\u0093NUMPY\u0001\u0000v\u0000{'descr': '<f8', 'fortran_order': False, 'shape': (2, 3), } \n\u0000\u0000\u0000\u0000\u0000\u0000\u00e8\u00bf\u0000\u0000\u0000\u0000\u0000\u0000\u0090<\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u00e8?\u0000\u0000\u0000\u0000\u0000\u0000\u0090\u00bc\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000", {}], "molprops": {"ofe-name": ""}} [[0, 0]] diff --git a/gufe/tests/dev/serialization_test_templates.py b/gufe/tests/dev/serialization_test_templates.py index 95d45a85..6d1f7e56 100644 --- a/gufe/tests/dev/serialization_test_templates.py +++ b/gufe/tests/dev/serialization_test_templates.py @@ -4,7 +4,8 @@ from rdkit import Chem from rdkit.Chem import AllChem -from gufe import SmallMoleculeComponent, LigandNetwork, LigandAtomMapping + +from gufe import LigandAtomMapping, LigandNetwork, SmallMoleculeComponent def mol_from_smiles(smiles: str) -> Chem.Mol: diff --git a/gufe/tests/storage/test_externalresource.py b/gufe/tests/storage/test_externalresource.py index 43736991..adfe9170 100644 --- a/gufe/tests/storage/test_externalresource.py +++ b/gufe/tests/storage/test_externalresource.py @@ -1,13 +1,12 @@ -import pytest -import pathlib import hashlib import os +import pathlib from unittest import mock +import pytest + +from gufe.storage.errors import ChangedExternalResourceError, MissingExternalResourceError from gufe.storage.externalresource import FileStorage, MemoryStorage -from gufe.storage.errors import ( - MissingExternalResourceError, ChangedExternalResourceError -) # NOTE: Tests for the abstract base are just part of the tests of its # subclasses @@ -21,43 +20,46 @@ def file_storage(tmp_path): * foo.txt : contents "bar" * with/directory.txt : contents "in a directory" """ - with open(tmp_path / 'foo.txt', 'wb') as foo: - foo.write("bar".encode("utf-8")) + with open(tmp_path / "foo.txt", "wb") as foo: + foo.write(b"bar") - inner_dir = tmp_path / 'with' + inner_dir = tmp_path / "with" inner_dir.mkdir() - with open(inner_dir / 'directory.txt', 'wb') as with_dir: - with_dir.write("in a directory".encode("utf-8")) + with open(inner_dir / "directory.txt", "wb") as with_dir: + with_dir.write(b"in a directory") return FileStorage(tmp_path) class TestFileStorage: - @pytest.mark.parametrize('filename, expected', [ - ('foo.txt', True), - ('notexisting.txt', False), - ('with/directory.txt', True), - ]) + @pytest.mark.parametrize( + "filename, expected", + [ + ("foo.txt", True), + ("notexisting.txt", False), + ("with/directory.txt", True), + ], + ) def test_exists(self, filename, expected, file_storage): assert file_storage.exists(filename) == expected def test_store_bytes(self, file_storage): fileloc = file_storage.root_dir / "bar.txt" assert not fileloc.exists() - as_bytes = "This is bar".encode('utf-8') + as_bytes = b"This is bar" file_storage.store_bytes("bar.txt", as_bytes) assert fileloc.exists() - with open(fileloc, 'rb') as f: + with open(fileloc, "rb") as f: assert as_bytes == f.read() - @pytest.mark.parametrize('nested', [True, False]) + @pytest.mark.parametrize("nested", [True, False]) def test_store_path(self, file_storage, nested): orig_file = file_storage.root_dir / ".hidden" / "bar.txt" orig_file.parent.mkdir() - as_bytes = "This is bar".encode('utf-8') - with open(orig_file, 'wb') as f: + as_bytes = b"This is bar" + with open(orig_file, "wb") as f: f.write(as_bytes) nested_dir = "nested" if nested else "" @@ -67,7 +69,7 @@ def test_store_path(self, file_storage, nested): file_storage.store_path(fileloc, orig_file) assert fileloc.exists() - with open(fileloc, 'rb') as f: + with open(fileloc, "rb") as f: assert as_bytes == f.read() def test_eq(self, tmp_path): @@ -82,25 +84,28 @@ def test_delete(self, file_storage): file_storage.delete("foo.txt") assert not path.exists() - @pytest.mark.parametrize('prefix,expected', [ - ("", {'foo.txt', 'foo_dir/a.txt', 'foo_dir/b.txt'}), - ("foo", {'foo.txt', 'foo_dir/a.txt', 'foo_dir/b.txt'}), - ("foo_dir/", {'foo_dir/a.txt', 'foo_dir/b.txt'}), - ("foo_dir/a", {'foo_dir/a.txt'}), - ("foo_dir/a.txt", {'foo_dir/a.txt'}), - ("baz", set()), - ]) + @pytest.mark.parametrize( + "prefix,expected", + [ + ("", {"foo.txt", "foo_dir/a.txt", "foo_dir/b.txt"}), + ("foo", {"foo.txt", "foo_dir/a.txt", "foo_dir/b.txt"}), + ("foo_dir/", {"foo_dir/a.txt", "foo_dir/b.txt"}), + ("foo_dir/a", {"foo_dir/a.txt"}), + ("foo_dir/a.txt", {"foo_dir/a.txt"}), + ("baz", set()), + ], + ) def test_iter_contents(self, tmp_path, prefix, expected): files = [ - 'foo.txt', - 'foo_dir/a.txt', - 'foo_dir/b.txt', + "foo.txt", + "foo_dir/a.txt", + "foo_dir/b.txt", ] for file in files: path = tmp_path / file path.parent.mkdir(parents=True, exist_ok=True) assert not path.exists() - with open(path, 'wb') as f: + with open(path, "wb") as f: f.write(b"") storage = FileStorage(tmp_path) @@ -108,8 +113,7 @@ def test_iter_contents(self, tmp_path, prefix, expected): assert set(storage.iter_contents(prefix)) == expected def test_delete_error_not_existing(self, file_storage): - with pytest.raises(MissingExternalResourceError, - match="does not exist"): + with pytest.raises(MissingExternalResourceError, match="does not exist"): file_storage.delete("baz.txt") def test_get_filename(self, file_storage): @@ -119,7 +123,7 @@ def test_get_filename(self, file_storage): def test_load_stream(self, file_storage): with file_storage.load_stream("foo.txt") as f: - results = f.read().decode('utf-8') + results = f.read().decode("utf-8") assert results == "bar" @@ -130,25 +134,24 @@ def test_load_stream_error_missing(self, file_storage): class TestMemoryStorage: def setup_method(self): - self.contents = {'path/to/foo.txt': 'bar'.encode('utf-8')} + self.contents = {"path/to/foo.txt": b"bar"} self.storage = MemoryStorage() self.storage._data = dict(self.contents) - @pytest.mark.parametrize('expected', [True, False]) + @pytest.mark.parametrize("expected", [True, False]) def test_exists(self, expected): path = "path/to/foo.txt" if expected else "path/to/bar.txt" assert self.storage.exists(path) is expected def test_delete(self): # checks internal state - assert 'path/to/foo.txt' in self.storage._data - self.storage.delete('path/to/foo.txt') - assert 'path/to/foo.txt' not in self.storage._data + assert "path/to/foo.txt" in self.storage._data + self.storage.delete("path/to/foo.txt") + assert "path/to/foo.txt" not in self.storage._data def test_delete_error_not_existing(self): - with pytest.raises(MissingExternalResourceError, - match="Unable to delete"): - self.storage.delete('does/not/exist.txt') + with pytest.raises(MissingExternalResourceError, match="Unable to delete"): + self.storage.delete("does/not/exist.txt") def test_store_bytes(self): storage = MemoryStorage() @@ -162,7 +165,7 @@ def test_store_path(self, tmp_path): for label, data in self.contents.items(): path = tmp_path / label path.parent.mkdir(parents=True, exist_ok=True) - with open(path, mode='wb') as f: + with open(path, mode="wb") as f: f.write(data) storage.store_path(label, path) @@ -174,14 +177,17 @@ def test_eq(self): assert reference == reference assert reference != MemoryStorage() - @pytest.mark.parametrize('prefix,expected', [ - ("", {'foo.txt', 'foo_dir/a.txt', 'foo_dir/b.txt'}), - ("foo", {'foo.txt', 'foo_dir/a.txt', 'foo_dir/b.txt'}), - ("foo_dir/", {'foo_dir/a.txt', 'foo_dir/b.txt'}), - ("foo_dir/a", {'foo_dir/a.txt'}), - ("foo_dir/a.txt", {'foo_dir/a.txt'}), - ("baz", set()), - ]) + @pytest.mark.parametrize( + "prefix,expected", + [ + ("", {"foo.txt", "foo_dir/a.txt", "foo_dir/b.txt"}), + ("foo", {"foo.txt", "foo_dir/a.txt", "foo_dir/b.txt"}), + ("foo_dir/", {"foo_dir/a.txt", "foo_dir/b.txt"}), + ("foo_dir/a", {"foo_dir/a.txt"}), + ("foo_dir/a.txt", {"foo_dir/a.txt"}), + ("baz", set()), + ], + ) def test_iter_contents(self, prefix, expected): storage = MemoryStorage() storage._data = { @@ -198,6 +204,6 @@ def test_get_filename(self): def test_load_stream(self): path = "path/to/foo.txt" with self.storage.load_stream(path) as f: - results = f.read().decode('utf-8') + results = f.read().decode("utf-8") assert results == "bar" diff --git a/gufe/tests/test_alchemicalnetwork.py b/gufe/tests/test_alchemicalnetwork.py index 8231c51d..d65995e0 100644 --- a/gufe/tests/test_alchemicalnetwork.py +++ b/gufe/tests/test_alchemicalnetwork.py @@ -1,8 +1,8 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -import pytest import networkx as nx +import pytest from gufe import AlchemicalNetwork, ChemicalSystem, Transformation @@ -56,16 +56,16 @@ def test_connected_subgraphs_multiple_subgraphs(self, benzene_variants_star_map) subgraphs = [subgraph for subgraph in alnet.connected_subgraphs()] - assert set([len(subgraph.nodes) for subgraph in subgraphs]) == {6,7,1} + assert {len(subgraph.nodes) for subgraph in subgraphs} == {6, 7, 1} # which graph has the removed node is not deterministic, so we just # check that one graph is all-solvent and the other is all-protein for subgraph in subgraphs: components = [frozenset(n.components.keys()) for n in subgraph.nodes] - if {'solvent','protein','ligand'} in components: - assert set(components) == {frozenset({'solvent','protein','ligand'})} + if {"solvent", "protein", "ligand"} in components: + assert set(components) == {frozenset({"solvent", "protein", "ligand"})} else: - assert set(components) == {frozenset({'solvent','ligand'})} + assert set(components) == {frozenset({"solvent", "ligand"})} def test_connected_subgraphs_one_subgraph(self, benzene_variants_ligand_star_map): """Return the same network if it only contains one connected component.""" diff --git a/gufe/tests/test_chemicalsystem.py b/gufe/tests/test_chemicalsystem.py index 661ab19c..414cd96b 100644 --- a/gufe/tests/test_chemicalsystem.py +++ b/gufe/tests/test_chemicalsystem.py @@ -1,62 +1,56 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -import pytest import numpy as np +import pytest from gufe import ChemicalSystem from .test_tokenization import GufeTokenizableTestsMixin + def test_ligand_construction(solv_comp, toluene_ligand_comp): # sanity checks on construction state = ChemicalSystem( - {'solvent': solv_comp, - 'ligand': toluene_ligand_comp}, + {"solvent": solv_comp, "ligand": toluene_ligand_comp}, ) assert len(state.components) == 2 assert len(state) == 2 - assert list(state) == ['solvent', 'ligand'] + assert list(state) == ["solvent", "ligand"] - assert state.components['solvent'] == solv_comp - assert state.components['ligand'] == toluene_ligand_comp - assert state['solvent'] == solv_comp - assert state['ligand'] == toluene_ligand_comp + assert state.components["solvent"] == solv_comp + assert state.components["ligand"] == toluene_ligand_comp + assert state["solvent"] == solv_comp + assert state["ligand"] == toluene_ligand_comp def test_complex_construction(prot_comp, solv_comp, toluene_ligand_comp): # sanity checks on construction state = ChemicalSystem( - {'protein': prot_comp, - 'solvent': solv_comp, - 'ligand': toluene_ligand_comp}, + {"protein": prot_comp, "solvent": solv_comp, "ligand": toluene_ligand_comp}, ) assert len(state.components) == 3 assert len(state) == 3 - assert list(state) == ['protein', 'solvent', 'ligand'] + assert list(state) == ["protein", "solvent", "ligand"] - assert state.components['protein'] == prot_comp - assert state.components['solvent'] == solv_comp - assert state.components['ligand'] == toluene_ligand_comp - assert state['protein'] == prot_comp - assert state['solvent'] == solv_comp - assert state['ligand'] == toluene_ligand_comp + assert state.components["protein"] == prot_comp + assert state.components["solvent"] == solv_comp + assert state.components["ligand"] == toluene_ligand_comp + assert state["protein"] == prot_comp + assert state["solvent"] == solv_comp + assert state["ligand"] == toluene_ligand_comp def test_hash_and_eq(prot_comp, solv_comp, toluene_ligand_comp): - c1 = ChemicalSystem({'protein': prot_comp, - 'solvent': solv_comp, - 'ligand': toluene_ligand_comp}) + c1 = ChemicalSystem({"protein": prot_comp, "solvent": solv_comp, "ligand": toluene_ligand_comp}) - c2 = ChemicalSystem({'solvent': solv_comp, - 'ligand': toluene_ligand_comp, - 'protein': prot_comp}) + c2 = ChemicalSystem({"solvent": solv_comp, "ligand": toluene_ligand_comp, "protein": prot_comp}) assert c1 == c2 assert hash(c1) == hash(c2) @@ -68,8 +62,7 @@ def test_chemical_system_neq_1(solvated_complex, prot_comp): assert hash(solvated_complex) != hash(prot_comp) -def test_chemical_system_neq_2(solvated_complex, prot_comp, solv_comp, - toluene_ligand_comp): +def test_chemical_system_neq_2(solvated_complex, prot_comp, solv_comp, toluene_ligand_comp): # names are different complex2 = ChemicalSystem( {"protein": prot_comp, "solvent": solv_comp, "ligand": toluene_ligand_comp}, @@ -86,13 +79,10 @@ def test_chemical_system_neq_4(solvated_complex, solvated_ligand): assert hash(solvated_complex) != hash(solvated_ligand) -def test_chemical_system_neq_5(solvated_complex, prot_comp, solv_comp, - phenol_ligand_comp): +def test_chemical_system_neq_5(solvated_complex, prot_comp, solv_comp, phenol_ligand_comp): # same component keys, but different components complex2 = ChemicalSystem( - {'protein': prot_comp, - 'solvent': solv_comp, - 'ligand': phenol_ligand_comp}, + {"protein": prot_comp, "solvent": solv_comp, "ligand": phenol_ligand_comp}, ) assert solvated_complex != complex2 assert hash(solvated_complex) != hash(complex2) @@ -123,6 +113,5 @@ class TestChemicalSystem(GufeTokenizableTestsMixin): @pytest.fixture def instance(self, solv_comp, toluene_ligand_comp): return ChemicalSystem( - {'solvent': solv_comp, - 'ligand': toluene_ligand_comp}, - ) + {"solvent": solv_comp, "ligand": toluene_ligand_comp}, + ) diff --git a/gufe/tests/test_custom_json.py b/gufe/tests/test_custom_json.py index 9d4eebe1..dbc61ba7 100644 --- a/gufe/tests/test_custom_json.py +++ b/gufe/tests/test_custom_json.py @@ -5,17 +5,19 @@ import json import pathlib +from uuid import uuid4 import numpy as np import openff.units -from openff.units import unit import pytest from numpy import testing as npt -from uuid import uuid4 +from openff.units import unit + +from gufe import tokenization from gufe.custom_codecs import ( BYTES_CODEC, - NUMPY_CODEC, NPY_DTYPE_CODEC, + NUMPY_CODEC, OPENFF_QUANTITY_CODEC, OPENFF_UNIT_CODEC, PATH_CODEC, @@ -23,7 +25,6 @@ UUID_CODEC, ) from gufe.custom_json import JSONSerializerDeserializer, custom_json_factory -from gufe import tokenization from gufe.settings import models @@ -48,14 +49,14 @@ def test_add_existing_codec(self): assert len(serialization.codecs) == 1 -@pytest.mark.parametrize('obj', [ - np.array([[1.0, 0.0], [2.0, 3.2]]), - np.float32(1.1) -]) -@pytest.mark.parametrize('codecs', [ - [BYTES_CODEC, NUMPY_CODEC, NPY_DTYPE_CODEC], - [NPY_DTYPE_CODEC, BYTES_CODEC, NUMPY_CODEC], -]) +@pytest.mark.parametrize("obj", [np.array([[1.0, 0.0], [2.0, 3.2]]), np.float32(1.1)]) +@pytest.mark.parametrize( + "codecs", + [ + [BYTES_CODEC, NUMPY_CODEC, NPY_DTYPE_CODEC], + [NPY_DTYPE_CODEC, BYTES_CODEC, NUMPY_CODEC], + ], +) def test_numpy_codec_order_roundtrip(obj, codecs): serialization = JSONSerializerDeserializer(codecs) serialized = serialization.serializer(obj) @@ -76,15 +77,15 @@ class CustomJSONCodingTest: """ def test_default(self): - for (obj, dct) in zip(self.objs, self.dcts): + for obj, dct in zip(self.objs, self.dcts): assert self.codec.default(obj) == dct def test_object_hook(self): - for (obj, dct) in zip(self.objs, self.dcts): + for obj, dct in zip(self.objs, self.dcts): assert self.codec.object_hook(dct) == obj def _test_round_trip(self, encoder, decoder): - for (obj, dct) in zip(self.objs, self.dcts): + for obj, dct in zip(self.objs, self.dcts): json_str = json.dumps(obj, cls=encoder) reconstructed = json.loads(json_str, cls=decoder) assert reconstructed == obj @@ -107,9 +108,20 @@ def test_not_mine(self): class TestNumpyCoding(CustomJSONCodingTest): def setup_method(self): self.codec = NUMPY_CODEC - self.objs = [np.array([[1.0, 0.0], [2.0, 3.2]]), np.array([1, 0]), - np.array([1.0, 2.0, 3.0], dtype=np.float32)] - shapes = [[2, 2], [2,], [3,]] + self.objs = [ + np.array([[1.0, 0.0], [2.0, 3.2]]), + np.array([1, 0]), + np.array([1.0, 2.0, 3.0], dtype=np.float32), + ] + shapes = [ + [2, 2], + [ + 2, + ], + [ + 3, + ], + ] dtypes = [str(arr.dtype) for arr in self.objs] # may change by system? byte_reps = [arr.tobytes() for arr in self.objs] self.dcts = [ @@ -126,13 +138,13 @@ def setup_method(self): def test_object_hook(self): # to get custom equality testing for numpy - for (obj, dct) in zip(self.objs, self.dcts): + for obj, dct in zip(self.objs, self.dcts): reconstructed = self.codec.object_hook(dct) npt.assert_array_equal(reconstructed, obj) def test_round_trip(self): encoder, decoder = custom_json_factory([self.codec, BYTES_CODEC]) - for (obj, dct) in zip(self.objs, self.dcts): + for obj, dct in zip(self.objs, self.dcts): json_str = json.dumps(obj, cls=encoder) reconstructed = json.loads(json_str, cls=decoder) npt.assert_array_equal(reconstructed, obj) @@ -147,15 +159,19 @@ def setup_method(self): # Note that np.float64 is treated as a float by the # default json encode (and so returns a float not a numpy # object). - self.objs = [np.bool_(True), np.float16(1.0), np.float32(1.0), - np.complex128(1.0), - np.clongdouble(1.0), np.uint64(1)] + self.objs = [ + np.bool_(True), + np.float16(1.0), + np.float32(1.0), + np.complex128(1.0), + np.clongdouble(1.0), + np.uint64(1), + ] dtypes = [str(a.dtype) for a in self.objs] byte_reps = [a.tobytes() for a in self.objs] # Overly complicated extraction of the class name # to deal with the bool_ -> bool dtype class name problem - classes = [str(a.__class__).split("'")[1].split('.')[1] - for a in self.objs] + classes = [str(a.__class__).split("'")[1].split(".")[1] for a in self.objs] self.dcts = [ { ":is_custom:": True, @@ -221,19 +237,23 @@ def setup_method(self): ], "small_molecule_forcefield": "openff-2.1.1", "nonbonded_method": "PME", - "nonbonded_cutoff": {':is_custom:': True, - 'magnitude': 1.0, - 'pint_unit_registry': 'openff_units', - 'unit': 'nanometer'}, + "nonbonded_cutoff": { + ":is_custom:": True, + "magnitude": 1.0, + "pint_unit_registry": "openff_units", + "unit": "nanometer", + }, }, "thermo_settings": { "__class__": "ThermoSettings", "__module__": "gufe.settings.models", ":is_custom:": True, - "temperature": {":is_custom:": True, - "magnitude": 300.0, - "pint_unit_registry": "openff_units", - "unit": "kelvin"}, + "temperature": { + ":is_custom:": True, + "magnitude": 300.0, + "pint_unit_registry": "openff_units", + "unit": "kelvin", + }, "pressure": None, "ph": None, "redox_potential": None, @@ -275,9 +295,7 @@ def setup_method(self): def test_openff_quantity_array_roundtrip(): - thing = unit.Quantity.from_list([ - (i + 1.0)*unit.kelvin for i in range(10) - ]) + thing = unit.Quantity.from_list([(i + 1.0) * unit.kelvin for i in range(10)]) dumped = json.dumps(thing, cls=tokenization.JSON_HANDLER.encoder) @@ -305,9 +323,7 @@ def setup_method(self): class TestUUIDCodec(CustomJSONCodingTest): def setup_method(self): self.codec = UUID_CODEC - self.objs = [ - uuid4() - ] + self.objs = [uuid4()] self.dcts = [ { ":is_custom:": True, diff --git a/gufe/tests/test_ligand_network.py b/gufe/tests/test_ligand_network.py index 9b3be8c4..a7116a8e 100644 --- a/gufe/tests/test_ligand_network.py +++ b/gufe/tests/test_ligand_network.py @@ -1,17 +1,17 @@ # This code is part of gufe and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/gufe -from typing import Iterable, NamedTuple -import pytest import importlib.resources -import gufe -from gufe.tests.test_protocol import DummyProtocol -from gufe import SmallMoleculeComponent, LigandNetwork, LigandAtomMapping +from collections.abc import Iterable +from typing import NamedTuple +import pytest +from networkx import NetworkXError from openff.units import unit - from rdkit import Chem -from networkx import NetworkXError +import gufe +from gufe import LigandAtomMapping, LigandNetwork, SmallMoleculeComponent +from gufe.tests.test_protocol import DummyProtocol from .test_tokenization import GufeTokenizableTestsMixin @@ -22,8 +22,10 @@ def mol_from_smiles(smi): return m + class _NetworkTestContainer(NamedTuple): """Container to facilitate network testing""" + network: LigandNetwork nodes: Iterable[SmallMoleculeComponent] edges: Iterable[LigandAtomMapping] @@ -33,8 +35,8 @@ class _NetworkTestContainer(NamedTuple): @pytest.fixture def ligandnetwork_graphml(): - with importlib.resources.path('gufe.tests.data', 'ligand_network.graphml') as file: - with open(file, 'r') as f: + with importlib.resources.path("gufe.tests.data", "ligand_network.graphml") as file: + with open(file) as f: yield f.read() @@ -96,24 +98,28 @@ def singleton_node_network(mols, std_edges): n_edges=3, ) + @pytest.fixture def real_molecules_network(benzene, phenol, toluene): """Small network with full mappings""" # benzene to phenol bp_mapping = {i: i for i in range(10)} - bp_mapping.update({10: 12, 11:11}) + bp_mapping.update({10: 12, 11: 11}) # benzene to toluene bt_mapping = {i: i + 4 for i in range(10)} bt_mapping.update({10: 2, 11: 14}) - network = gufe.LigandNetwork([ - gufe.LigandAtomMapping(benzene, toluene, bt_mapping), - gufe.LigandAtomMapping(benzene, phenol, bp_mapping), - ]) + network = gufe.LigandNetwork( + [ + gufe.LigandAtomMapping(benzene, toluene, bt_mapping), + gufe.LigandAtomMapping(benzene, phenol, bp_mapping), + ] + ) return network -@pytest.fixture(params=['simple', 'doubled_edge', 'singleton_node']) + +@pytest.fixture(params=["simple", "doubled_edge", "singleton_node"]) def network_container( request, simple_network, @@ -122,9 +128,9 @@ def network_container( ): """Fixture to allow parameterization of the network test""" network_dct = { - 'simple': simple_network, - 'doubled_edge': doubled_edge_network, - 'singleton_node': singleton_node_network, + "simple": simple_network, + "doubled_edge": doubled_edge_network, + "singleton_node": singleton_node_network, } return network_dct[request.param] @@ -140,7 +146,7 @@ def instance(self, simple_network): def test_node_type(self, network_container): n = network_container.network - assert all((isinstance(node, SmallMoleculeComponent) for node in n.nodes)) + assert all(isinstance(node, SmallMoleculeComponent) for node in n.nodes) def test_graph(self, network_container): # The NetworkX graph that comes from the ``.graph`` property should @@ -150,21 +156,19 @@ def test_graph(self, network_container): assert set(graph.nodes) == set(network_container.nodes) assert len(graph.edges) == network_container.n_edges # extract the AtomMappings from the nx edges - mappings = [ - atommapping for _, _, atommapping in graph.edges.data('object') - ] + mappings = [atommapping for _, _, atommapping in graph.edges.data("object")] assert set(mappings) == set(network_container.edges) # ensure LigandAtomMapping stored in nx edge is consistent with nx edge - for mol1, mol2, atommapping in graph.edges.data('object'): + for mol1, mol2, atommapping in graph.edges.data("object"): assert atommapping.componentA == mol1 assert atommapping.componentB == mol2 def test_graph_annotations(self, mols, std_edges): mol1, mol2, mol3 = mols edge12, edge23, edge13 = std_edges - annotated = edge12.with_annotations({'foo': 'bar'}) + annotated = edge12.with_annotations({"foo": "bar"}) network = LigandNetwork([annotated, edge23, edge13]) - assert network.graph[mol1][mol2][0]['foo'] == 'bar' + assert network.graph[mol1][mol2][0]["foo"] == "bar" def test_graph_immutability(self, mols, network_container): # The NetworkX graph that comes from that ``.graph`` property should @@ -285,7 +289,7 @@ def test_is_connected(self, simple_network): def test_is_not_connected(self, singleton_node_network): assert not singleton_node_network.network.is_connected() - @pytest.mark.parametrize('with_cofactor', [True, False]) + @pytest.mark.parametrize("with_cofactor", [True, False]) def test_to_rbfe_alchemical_network( self, real_molecules_network, @@ -297,25 +301,22 @@ def test_to_rbfe_alchemical_network( # obviously, this particular set of ligands with this particular # protein makes no sense, but we should still be able to set it up if with_cofactor: - others = {'cofactor': request.getfixturevalue('styrene')} + others = {"cofactor": request.getfixturevalue("styrene")} else: others = {} protocol = DummyProtocol(DummyProtocol.default_settings()) rbfe = real_molecules_network.to_rbfe_alchemical_network( - solvent=solv_comp, - protein=prot_comp, - protocol=protocol, - **others + solvent=solv_comp, protein=prot_comp, protocol=protocol, **others ) expected_names = { - 'easy_rbfe_benzene_solvent_toluene_solvent', - 'easy_rbfe_benzene_complex_toluene_complex', - 'easy_rbfe_benzene_solvent_phenol_solvent', - 'easy_rbfe_benzene_complex_phenol_complex', + "easy_rbfe_benzene_solvent_toluene_solvent", + "easy_rbfe_benzene_complex_toluene_complex", + "easy_rbfe_benzene_solvent_phenol_solvent", + "easy_rbfe_benzene_complex_phenol_complex", } - names = set(edge.name for edge in rbfe.edges) + names = {edge.name for edge in rbfe.edges} assert names == expected_names assert len(rbfe.edges) == 2 * len(real_molecules_network.edges) @@ -324,35 +325,29 @@ def test_to_rbfe_alchemical_network( compsA = edge.stateA.components compsB = edge.stateB.components - if 'solvent' in edge.name: - labels = {'solvent', 'ligand'} - elif 'complex' in edge.name: - labels = {'solvent', 'ligand', 'protein'} + if "solvent" in edge.name: + labels = {"solvent", "ligand"} + elif "complex" in edge.name: + labels = {"solvent", "ligand", "protein"} if with_cofactor: - labels.add('cofactor') + labels.add("cofactor") else: # -no-cov- - raise RuntimeError("Something went weird in testing. Unable " - f"to get leg for edge {edge}") + raise RuntimeError("Something went weird in testing. Unable " f"to get leg for edge {edge}") assert set(compsA) == labels assert set(compsB) == labels - assert compsA['ligand'] != compsB['ligand'] - assert compsA['ligand'].name == 'benzene' - assert compsA['solvent'] == compsB['solvent'] + assert compsA["ligand"] != compsB["ligand"] + assert compsA["ligand"].name == "benzene" + assert compsA["solvent"] == compsB["solvent"] # for things that might not always exist, use .get - assert compsA.get('protein') == compsB.get('protein') - assert compsA.get('cofactor') == compsB.get('cofactor') + assert compsA.get("protein") == compsB.get("protein") + assert compsA.get("cofactor") == compsB.get("cofactor") assert isinstance(edge.mapping, gufe.ComponentMapping) assert edge.mapping in real_molecules_network.edges - def test_to_rbfe_alchemical_network_autoname_false( - self, - real_molecules_network, - prot_comp, - solv_comp - ): + def test_to_rbfe_alchemical_network_autoname_false(self, real_molecules_network, prot_comp, solv_comp): rbfe = real_molecules_network.to_rbfe_alchemical_network( solvent=solv_comp, protein=prot_comp, @@ -364,12 +359,7 @@ def test_to_rbfe_alchemical_network_autoname_false( for sys in [edge.stateA, edge.stateB]: assert sys.name == "" - def test_to_rbfe_alchemical_network_autoname_true( - self, - real_molecules_network, - prot_comp, - solv_comp - ): + def test_to_rbfe_alchemical_network_autoname_true(self, real_molecules_network, prot_comp, solv_comp): rbfe = real_molecules_network.to_rbfe_alchemical_network( solvent=solv_comp, protein=prot_comp, @@ -378,33 +368,28 @@ def test_to_rbfe_alchemical_network_autoname_true( autoname_prefix="", ) expected_names = { - 'benzene_complex_toluene_complex', - 'benzene_solvent_toluene_solvent', - 'benzene_complex_phenol_complex', - 'benzene_solvent_phenol_solvent', + "benzene_complex_toluene_complex", + "benzene_solvent_toluene_solvent", + "benzene_complex_phenol_complex", + "benzene_solvent_phenol_solvent", } - names = set(edge.name for edge in rbfe.edges) + names = {edge.name for edge in rbfe.edges} assert names == expected_names @pytest.mark.xfail # method removed and on hold for now - def test_to_rhfe_alchemical_network(self, real_molecules_network, - solv_comp): + def test_to_rhfe_alchemical_network(self, real_molecules_network, solv_comp): others = {} protocol = DummyProtocol(DummyProtocol.default_settings()) - rhfe = real_molecules_network.to_rhfe_alchemical_network( - solvent=solv_comp, - protocol=protocol, - **others - ) + rhfe = real_molecules_network.to_rhfe_alchemical_network(solvent=solv_comp, protocol=protocol, **others) expected_names = { - 'easy_rhfe_benzene_vacuum_toluene_vacuum', - 'easy_rhfe_benzene_solvent_toluene_solvent', - 'easy_rhfe_benzene_vacuum_phenol_vacuum', - 'easy_rhfe_benzene_solvent_phenol_solvent', + "easy_rhfe_benzene_vacuum_toluene_vacuum", + "easy_rhfe_benzene_solvent_toluene_solvent", + "easy_rhfe_benzene_vacuum_phenol_vacuum", + "easy_rhfe_benzene_solvent_phenol_solvent", } - names = set(edge.name for edge in rhfe.edges) + names = {edge.name for edge in rhfe.edges} assert names == expected_names assert len(rhfe.edges) == 2 * len(real_molecules_network.edges) @@ -414,25 +399,24 @@ def test_to_rhfe_alchemical_network(self, real_molecules_network, compsA = edge.stateA.components compsB = edge.stateB.components - if 'vacuum' in edge.name: - labels = {'ligand'} - elif 'solvent' in edge.name: - labels = {'ligand', 'solvent'} + if "vacuum" in edge.name: + labels = {"ligand"} + elif "solvent" in edge.name: + labels = {"ligand", "solvent"} else: # -no-cov- - raise RuntimeError("Something went weird in testing. Unable " - f"to get leg for edge {edge}") + raise RuntimeError("Something went weird in testing. Unable " f"to get leg for edge {edge}") labels |= set(others) assert set(compsA) == labels assert set(compsB) == labels - assert compsA['ligand'] != compsB['ligand'] - assert compsA['ligand'].name == 'benzene' - assert compsA.get('solvent') == compsB.get('solvent') + assert compsA["ligand"] != compsB["ligand"] + assert compsA["ligand"].name == "benzene" + assert compsA.get("solvent") == compsB.get("solvent") - assert list(edge.mapping) == ['ligand'] - assert edge.mapping['ligand'] in real_molecules_network.edges + assert list(edge.mapping) == ["ligand"] + assert edge.mapping["ligand"] in real_molecules_network.edges def test_empty_ligand_network(mols): diff --git a/gufe/tests/test_ligandatommapping.py b/gufe/tests/test_ligandatommapping.py index 93809ac7..63ae076e 100644 --- a/gufe/tests/test_ligandatommapping.py +++ b/gufe/tests/test_ligandatommapping.py @@ -1,12 +1,14 @@ # This code is part of gufe and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/gufe import importlib -import pytest -import pathlib import json +import pathlib + import numpy as np -from rdkit import Chem +import pytest from openff.units import unit +from rdkit import Chem + import gufe from gufe import LigandAtomMapping, SmallMoleculeComponent @@ -20,7 +22,7 @@ def mol_from_smiles(smiles: str) -> gufe.SmallMoleculeComponent: return gufe.SmallMoleculeComponent(m) -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def simple_mapping(): """Disappearing oxygen on end @@ -28,15 +30,15 @@ def simple_mapping(): C C """ - molA = mol_from_smiles('CCO') - molB = mol_from_smiles('CC') + molA = mol_from_smiles("CCO") + molB = mol_from_smiles("CC") m = LigandAtomMapping(molA, molB, componentA_to_componentB={0: 0, 1: 1}) return m -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def other_mapping(): """Disappearing middle carbon @@ -44,8 +46,8 @@ def other_mapping(): C C """ - molA = mol_from_smiles('CCO') - molB = mol_from_smiles('CC') + molA = mol_from_smiles("CCO") + molB = mol_from_smiles("CC") m = LigandAtomMapping(molA, molB, componentA_to_componentB={0: 0, 2: 1}) @@ -54,55 +56,82 @@ def other_mapping(): @pytest.fixture def annotated_simple_mapping(simple_mapping): - mapping = LigandAtomMapping(simple_mapping.componentA, - simple_mapping.componentB, - simple_mapping.componentA_to_componentB, - annotations={'foo': 'bar'}) + mapping = LigandAtomMapping( + simple_mapping.componentA, + simple_mapping.componentB, + simple_mapping.componentA_to_componentB, + annotations={"foo": "bar"}, + ) return mapping -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def benzene_maps(): MAPS = { - 'phenol': {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, - 7: 7, 8: 8, 9: 9, 10: 12, 11: 11}, - 'anisole': {0: 5, 1: 6, 2: 7, 3: 8, 4: 9, 5: 10, - 6: 11, 7: 12, 8: 13, 9: 14, 10: 2, 11: 15}} + "phenol": { + 0: 0, + 1: 1, + 2: 2, + 3: 3, + 4: 4, + 5: 5, + 6: 6, + 7: 7, + 8: 8, + 9: 9, + 10: 12, + 11: 11, + }, + "anisole": { + 0: 5, + 1: 6, + 2: 7, + 3: 8, + 4: 9, + 5: 10, + 6: 11, + 7: 12, + 8: 13, + 9: 14, + 10: 2, + 11: 15, + }, + } return MAPS -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def benzene_phenol_mapping(benzene_transforms, benzene_maps): - molA = SmallMoleculeComponent(benzene_transforms['benzene'].to_rdkit()) - molB = SmallMoleculeComponent(benzene_transforms['phenol'].to_rdkit()) - m = LigandAtomMapping(molA, molB, benzene_maps['phenol']) + molA = SmallMoleculeComponent(benzene_transforms["benzene"].to_rdkit()) + molB = SmallMoleculeComponent(benzene_transforms["phenol"].to_rdkit()) + m = LigandAtomMapping(molA, molB, benzene_maps["phenol"]) return m -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def benzene_anisole_mapping(benzene_transforms, benzene_maps): - molA = SmallMoleculeComponent(benzene_transforms['benzene'].to_rdkit()) - molB = SmallMoleculeComponent(benzene_transforms['anisole'].to_rdkit()) - m = LigandAtomMapping(molA, molB, benzene_maps['anisole']) + molA = SmallMoleculeComponent(benzene_transforms["benzene"].to_rdkit()) + molB = SmallMoleculeComponent(benzene_transforms["anisole"].to_rdkit()) + m = LigandAtomMapping(molA, molB, benzene_maps["anisole"]) return m -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def atom_mapping_basic_test_files(): # a dict of {filenames.strip(mol2): SmallMoleculeComponent} for a simple # set of ligands files = {} for f in [ - '1,3,7-trimethylnaphthalene', - '1-butyl-4-methylbenzene', - '2,6-dimethylnaphthalene', - '2-methyl-6-propylnaphthalene', - '2-methylnaphthalene', - '2-naftanol', - 'methylcyclohexane', - 'toluene']: - with importlib.resources.path('gufe.tests.data.lomap_basic', - f + '.mol2') as fn: + "1,3,7-trimethylnaphthalene", + "1-butyl-4-methylbenzene", + "2,6-dimethylnaphthalene", + "2-methyl-6-propylnaphthalene", + "2-methylnaphthalene", + "2-naftanol", + "methylcyclohexane", + "toluene", + ]: + with importlib.resources.path("gufe.tests.data.lomap_basic", f + ".mol2") as fn: mol = Chem.MolFromMol2File(str(fn), removeHs=False) files[f] = SmallMoleculeComponent(mol, name=f) @@ -120,15 +149,25 @@ def test_atommapping_usage(simple_mapping): def test_mapping_inversion(benzene_phenol_mapping): assert benzene_phenol_mapping.componentB_to_componentA == { - 0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9, 11: 11, - 12: 10 + 0: 0, + 1: 1, + 2: 2, + 3: 3, + 4: 4, + 5: 5, + 6: 6, + 7: 7, + 8: 8, + 9: 9, + 11: 11, + 12: 10, } def test_mapping_distances(benzene_phenol_mapping): d = benzene_phenol_mapping.get_distances() - ref = [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.34005502, 0.] + ref = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.34005502, 0.0] assert isinstance(d, np.ndarray) for i, r in zip(d, ref): @@ -137,15 +176,27 @@ def test_mapping_distances(benzene_phenol_mapping): def test_uniques(atom_mapping_basic_test_files): mapping = LigandAtomMapping( - componentA=atom_mapping_basic_test_files['methylcyclohexane'], - componentB=atom_mapping_basic_test_files['toluene'], - componentA_to_componentB={ - 0: 6, 1: 7, 2: 8, 3: 9, 4: 10, 5: 11, 6: 12 - } + componentA=atom_mapping_basic_test_files["methylcyclohexane"], + componentB=atom_mapping_basic_test_files["toluene"], + componentA_to_componentB={0: 6, 1: 7, 2: 8, 3: 9, 4: 10, 5: 11, 6: 12}, ) - assert list(mapping.componentA_unique) == [7, 8, 9, 10, 11, 12, 13, 14, 15, - 16, 17, 18, 19, 20] + assert list(mapping.componentA_unique) == [ + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + ] assert list(mapping.componentB_unique) == [0, 1, 2, 3, 4, 5, 13, 14] @@ -166,25 +217,23 @@ def test_atommapping_hash(simple_mapping, other_mapping): def test_draw_mapping_cairo(tmpdir, simple_mapping): with tmpdir.as_cwd(): - simple_mapping.draw_to_file('test.png') - filed = pathlib.Path('test.png') + simple_mapping.draw_to_file("test.png") + filed = pathlib.Path("test.png") assert filed.exists() def test_draw_mapping_svg(tmpdir, other_mapping): with tmpdir.as_cwd(): d2d = Chem.Draw.rdMolDraw2D.MolDraw2DSVG(600, 300, 300, 300) - other_mapping.draw_to_file('test.svg', d2d=d2d) - filed = pathlib.Path('test.svg') + other_mapping.draw_to_file("test.svg", d2d=d2d) + filed = pathlib.Path("test.svg") assert filed.exists() class TestLigandAtomMappingSerialization: - def test_deserialize_roundtrip(self, benzene_phenol_mapping, - benzene_anisole_mapping): + def test_deserialize_roundtrip(self, benzene_phenol_mapping, benzene_anisole_mapping): - roundtrip = LigandAtomMapping.from_dict( - benzene_phenol_mapping.to_dict()) + roundtrip = LigandAtomMapping.from_dict(benzene_phenol_mapping.to_dict()) assert roundtrip == benzene_phenol_mapping @@ -195,10 +244,10 @@ def test_deserialize_roundtrip(self, benzene_phenol_mapping, def test_file_roundtrip(self, benzene_phenol_mapping, tmpdir): with tmpdir.as_cwd(): - with open('tmpfile.json', 'w') as f: + with open("tmpfile.json", "w") as f: f.write(json.dumps(benzene_phenol_mapping.to_dict())) - with open('tmpfile.json', 'r') as f: + with open("tmpfile.json") as f: d = json.load(f) assert isinstance(d, dict) @@ -207,27 +256,26 @@ def test_file_roundtrip(self, benzene_phenol_mapping, tmpdir): assert roundtrip == benzene_phenol_mapping -def test_annotated_atommapping_hash_eq(simple_mapping, - annotated_simple_mapping): +def test_annotated_atommapping_hash_eq(simple_mapping, annotated_simple_mapping): assert annotated_simple_mapping != simple_mapping assert hash(annotated_simple_mapping) != hash(simple_mapping) def test_annotation_immutability(annotated_simple_mapping): annot1 = annotated_simple_mapping.annotations - annot1['foo'] = 'baz' + annot1["foo"] = "baz" annot2 = annotated_simple_mapping.annotations assert annot1 != annot2 - assert annot2 == {'foo': 'bar'} + assert annot2 == {"foo": "bar"} def test_with_annotations(simple_mapping, annotated_simple_mapping): - new_annot = simple_mapping.with_annotations({'foo': 'bar'}) + new_annot = simple_mapping.with_annotations({"foo": "bar"}) assert new_annot == annotated_simple_mapping def test_with_fancy_annotations(simple_mapping): - m = simple_mapping.with_annotations({'thing': 4.0 * unit.nanometer}) + m = simple_mapping.with_annotations({"thing": 4.0 * unit.nanometer}) assert m.key @@ -240,36 +288,28 @@ class TestLigandAtomMappingBoundsChecks: @pytest.fixture def molA(self): # 9 atoms - return mol_from_smiles('CCO') + return mol_from_smiles("CCO") @pytest.fixture def molB(self): # 11 atoms - return mol_from_smiles('CCC') + return mol_from_smiles("CCC") def test_too_large_A(self, molA, molB): with pytest.raises(ValueError, match="invalid index for ComponentA"): - LigandAtomMapping(componentA=molA, - componentB=molB, - componentA_to_componentB={9: 5}) + LigandAtomMapping(componentA=molA, componentB=molB, componentA_to_componentB={9: 5}) def test_too_small_A(self, molA, molB): with pytest.raises(ValueError, match="invalid index for ComponentA"): - LigandAtomMapping(componentA=molA, - componentB=molB, - componentA_to_componentB={-2: 5}) + LigandAtomMapping(componentA=molA, componentB=molB, componentA_to_componentB={-2: 5}) def test_too_large_B(self, molA, molB): with pytest.raises(ValueError, match="invalid index for ComponentB"): - LigandAtomMapping(componentA=molA, - componentB=molB, - componentA_to_componentB={5: 11}) + LigandAtomMapping(componentA=molA, componentB=molB, componentA_to_componentB={5: 11}) def test_too_small_B(self, molA, molB): with pytest.raises(ValueError, match="invalid index for ComponentB"): - LigandAtomMapping(componentA=molA, - componentB=molB, - componentA_to_componentB={5: -1}) + LigandAtomMapping(componentA=molA, componentB=molB, componentA_to_componentB={5: -1}) class TestLigandAtomMapping(GufeTokenizableTestsMixin): diff --git a/gufe/tests/test_mapping.py b/gufe/tests/test_mapping.py index cb8c7370..620c2f55 100644 --- a/gufe/tests/test_mapping.py +++ b/gufe/tests/test_mapping.py @@ -1,17 +1,21 @@ # This code is part of gufe and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/gufe -import gufe -from gufe import AtomMapping import pytest +import gufe +from gufe import AtomMapping from .test_tokenization import GufeTokenizableTestsMixin class ExampleMapping(AtomMapping): - def __init__(self, molA: gufe.SmallMoleculeComponent, - molB: gufe.SmallMoleculeComponent, mapping): + def __init__( + self, + molA: gufe.SmallMoleculeComponent, + molB: gufe.SmallMoleculeComponent, + mapping, + ): super().__init__(molA, molB) self._mapping = mapping @@ -21,9 +25,9 @@ def _defaults(cls): def _to_dict(self): return { - 'molA': self._componentA, - 'molB': self._componentB, - 'mapping': self._mapping, + "molA": self._componentA, + "molB": self._componentB, + "mapping": self._mapping, } @classmethod @@ -37,12 +41,10 @@ def componentB_to_componentA(self): return {v: k for k, v in self._mapping} def componentA_unique(self): - return (i for i in range(self._molA.to_rdkit().GetNumAtoms()) - if i not in self._mapping) + return (i for i in range(self._molA.to_rdkit().GetNumAtoms()) if i not in self._mapping) def componentB_unique(self): - return (i for i in range(self._molB.to_rdkit().GetNumAtoms()) - if i not in self._mapping.values()) + return (i for i in range(self._molB.to_rdkit().GetNumAtoms()) if i not in self._mapping.values()) class TestMappingAbstractClass(GufeTokenizableTestsMixin): diff --git a/gufe/tests/test_mapping_visualization.py b/gufe/tests/test_mapping_visualization.py index 049ed626..756ad13a 100644 --- a/gufe/tests/test_mapping_visualization.py +++ b/gufe/tests/test_mapping_visualization.py @@ -1,18 +1,21 @@ -import pytest -from unittest import mock import inspect +from unittest import mock +import pytest from rdkit import Chem import gufe from gufe.visualization.mapping_visualization import ( - _match_elements, _get_unique_bonds_and_atoms, draw_mapping, - draw_one_molecule_mapping, draw_unhighlighted_molecule + _get_unique_bonds_and_atoms, + _match_elements, + draw_mapping, + draw_one_molecule_mapping, + draw_unhighlighted_molecule, ) # default colors currently used -_HIGHLIGHT_COLOR = (220/255, 50/255, 32/255, 1) -_CHANGED_ELEMENTS_COLOR = (0, 90/255, 181/255, 1) +_HIGHLIGHT_COLOR = (220 / 255, 50 / 255, 32 / 255, 1) +_CHANGED_ELEMENTS_COLOR = (0, 90 / 255, 181 / 255, 1) def bound_args(func, args, kwargs): @@ -38,11 +41,14 @@ def bound_args(func, args, kwargs): return bound.arguments -@pytest.mark.parametrize("at1, idx1, at2, idx2, response", [ - ["N", 0, "C", 0, False], - ["C", 0, "C", 0, True], - ["COC", 1, "NOC", 1, True], - ["COON", 2, "COC", 2, False]] +@pytest.mark.parametrize( + "at1, idx1, at2, idx2, response", + [ + ["N", 0, "C", 0, False], + ["C", 0, "C", 0, True], + ["COC", 1, "NOC", 1, True], + ["COON", 2, "COC", 2, False], + ], ) def test_match_elements(at1, idx1, at2, idx2, response): mol1 = Chem.MolFromSmiles(at1) @@ -52,56 +58,109 @@ def test_match_elements(at1, idx1, at2, idx2, response): assert retval == response -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def maps(): MAPS = { - 'phenol': {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, - 7: 7, 8: 8, 9: 9, 10: 12, 11: 11}, - 'anisole': {0: 5, 1: 6, 2: 7, 3: 8, 4: 9, 5: 10, - 6: 11, 7: 12, 8: 13, 9: 14, 10: 2, 11: 15}} + "phenol": { + 0: 0, + 1: 1, + 2: 2, + 3: 3, + 4: 4, + 5: 5, + 6: 6, + 7: 7, + 8: 8, + 9: 9, + 10: 12, + 11: 11, + }, + "anisole": { + 0: 5, + 1: 6, + 2: 7, + 3: 8, + 4: 9, + 5: 10, + 6: 11, + 7: 12, + 8: 13, + 9: 14, + 10: 2, + 11: 15, + }, + } return MAPS -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def benzene_phenol_mapping(benzene_transforms, maps): - mol1 = benzene_transforms['benzene'].to_rdkit() - mol2 = benzene_transforms['phenol'].to_rdkit() - mapping = maps['phenol'] + mol1 = benzene_transforms["benzene"].to_rdkit() + mol2 = benzene_transforms["phenol"].to_rdkit() + mapping = maps["phenol"] return mapping, mol1, mol2 -@pytest.mark.parametrize('molname, atoms, elems, bond_changes, bond_deletions', [ - ['phenol', {10, }, {12, }, {10, }, {12, }], - ['anisole', {0, 1, 3, 4}, {2, }, {13, }, {0, 1, 2, 3}] -]) -def test_benzene_to_phenol_uniques(molname, atoms, elems, bond_changes, bond_deletions, - benzene_transforms, maps): - mol1 = benzene_transforms['benzene'] +@pytest.mark.parametrize( + "molname, atoms, elems, bond_changes, bond_deletions", + [ + [ + "phenol", + { + 10, + }, + { + 12, + }, + { + 10, + }, + { + 12, + }, + ], + [ + "anisole", + {0, 1, 3, 4}, + { + 2, + }, + { + 13, + }, + {0, 1, 2, 3}, + ], + ], +) +def test_benzene_to_phenol_uniques(molname, atoms, elems, bond_changes, bond_deletions, benzene_transforms, maps): + mol1 = benzene_transforms["benzene"] mol2 = benzene_transforms[molname] mapping = maps[molname] - uniques = _get_unique_bonds_and_atoms(mapping, - mol1.to_rdkit(), mol2.to_rdkit()) + uniques = _get_unique_bonds_and_atoms(mapping, mol1.to_rdkit(), mol2.to_rdkit()) # The benzene perturbations don't change # no unique atoms in benzene - assert uniques['atoms'] == set() + assert uniques["atoms"] == set() # H->O - assert uniques['elements'] == {10, } + assert uniques["elements"] == { + 10, + } # One bond involved - assert uniques['bond_changes'] == {10, } + assert uniques["bond_changes"] == { + 10, + } # invert and check the molB uniques inv_map = {v: k for k, v in mapping.items()} - uniques = _get_unique_bonds_and_atoms(inv_map, - mol2.to_rdkit(), mol1.to_rdkit()) + uniques = _get_unique_bonds_and_atoms(inv_map, mol2.to_rdkit(), mol1.to_rdkit()) - assert uniques['atoms'] == atoms - assert uniques['elements'] == elems - assert uniques['bond_changes'] == bond_changes - assert uniques['bond_deletions'] == bond_deletions + assert uniques["atoms"] == atoms + assert uniques["elements"] == elems + assert uniques["bond_changes"] == bond_changes + assert uniques["bond_deletions"] == bond_deletions @mock.patch("gufe.visualization.mapping_visualization._draw_molecules", autospec=True) @@ -112,20 +171,20 @@ def test_draw_mapping(mock_func, benzene_phenol_mapping): draw_mapping(mapping, mol1, mol2) mock_func.assert_called_once() - args = bound_args(mock_func, mock_func.call_args.args, - mock_func.call_args.kwargs) - assert args['mols'] == [mol1, mol2] - assert args['atoms_list'] == [{10}, {10, 12}] - assert args['bonds_list'] == [{10}, {10, 12}] - assert args['atom_colors'] == [{10: _CHANGED_ELEMENTS_COLOR}, - {12: _CHANGED_ELEMENTS_COLOR}] - assert args['highlight_color'] == _HIGHLIGHT_COLOR - - -@pytest.mark.parametrize('inverted', [True, False]) + args = bound_args(mock_func, mock_func.call_args.args, mock_func.call_args.kwargs) + assert args["mols"] == [mol1, mol2] + assert args["atoms_list"] == [{10}, {10, 12}] + assert args["bonds_list"] == [{10}, {10, 12}] + assert args["atom_colors"] == [ + {10: _CHANGED_ELEMENTS_COLOR}, + {12: _CHANGED_ELEMENTS_COLOR}, + ] + assert args["highlight_color"] == _HIGHLIGHT_COLOR + + +@pytest.mark.parametrize("inverted", [True, False]) @mock.patch("gufe.visualization.mapping_visualization._draw_molecules", autospec=True) -def test_draw_one_molecule_mapping(mock_func, benzene_phenol_mapping, - inverted): +def test_draw_one_molecule_mapping(mock_func, benzene_phenol_mapping, inverted): # ensure that draw_one_molecule_mapping passes the desired parameters to # our internal _draw_molecules method mapping, mol1, mol2 = benzene_phenol_mapping @@ -143,30 +202,28 @@ def test_draw_one_molecule_mapping(mock_func, benzene_phenol_mapping, draw_one_molecule_mapping(mapping, mol1, mol2) mock_func.assert_called_once() - args = bound_args(mock_func, mock_func.call_args.args, - mock_func.call_args.kwargs) + args = bound_args(mock_func, mock_func.call_args.args, mock_func.call_args.kwargs) - assert args['mols'] == [mol1] - assert args['atoms_list'] == atoms_list - assert args['bonds_list'] == bonds_list - assert args['atom_colors'] == atom_colors - assert args['highlight_color'] == _HIGHLIGHT_COLOR + assert args["mols"] == [mol1] + assert args["atoms_list"] == atoms_list + assert args["bonds_list"] == bonds_list + assert args["atom_colors"] == atom_colors + assert args["highlight_color"] == _HIGHLIGHT_COLOR @mock.patch("gufe.visualization.mapping_visualization._draw_molecules", autospec=True) def test_draw_unhighlighted_molecule(mock_func, benzene_transforms): # ensure that draw_unhighlighted_molecule passes the desired parameters # to our internal _draw_molecules method - mol = benzene_transforms['benzene'].to_rdkit() + mol = benzene_transforms["benzene"].to_rdkit() draw_unhighlighted_molecule(mol) mock_func.assert_called_once() - args = bound_args(mock_func, mock_func.call_args.args, - mock_func.call_args.kwargs) - assert args['mols'] == [mol] - assert args['atoms_list'] == [[]] - assert args['bonds_list'] == [[]] - assert args['atom_colors'] == [{}] + args = bound_args(mock_func, mock_func.call_args.args, mock_func.call_args.kwargs) + assert args["mols"] == [mol] + assert args["atoms_list"] == [[]] + assert args["bonds_list"] == [[]] + assert args["atom_colors"] == [{}] # technically, we don't care what the highlight color is, so no # assertion on that @@ -186,4 +243,4 @@ def test_draw_one_molecule_integration_smoke(benzene_phenol_mapping): def test_draw_unhighlighted_molecule_integration_smoke(benzene_transforms): # integration test/smoke test to catch errors if the upstream drawing # code changes - draw_unhighlighted_molecule(benzene_transforms['benzene'].to_rdkit()) + draw_unhighlighted_molecule(benzene_transforms["benzene"].to_rdkit()) diff --git a/gufe/tests/test_models.py b/gufe/tests/test_models.py index 4fd67bf0..887cbdd7 100644 --- a/gufe/tests/test_models.py +++ b/gufe/tests/test_models.py @@ -6,14 +6,10 @@ import json -from openff.units import unit import pytest +from openff.units import unit -from gufe.settings.models import ( - OpenMMSystemGeneratorFFSettings, - Settings, - ThermoSettings, -) +from gufe.settings.models import OpenMMSystemGeneratorFFSettings, Settings, ThermoSettings def test_model_schema(): @@ -45,11 +41,16 @@ def test_default_settings(): my_settings.schema_json(indent=2) -@pytest.mark.parametrize('value,good', [ - ('parsnips', False), # shouldn't be allowed - ('hbonds', True), ('hangles', True), ('allbonds', True), # allowed options - ('HBonds', True), # check case insensitivity -]) +@pytest.mark.parametrize( + "value,good", + [ + ("parsnips", False), # shouldn't be allowed + ("hbonds", True), + ("hangles", True), + ("allbonds", True), # allowed options + ("HBonds", True), # check case insensitivity + ], +) def test_invalid_constraint(value, good): if good: s = OpenMMSystemGeneratorFFSettings(constraints=value) diff --git a/gufe/tests/test_molhashing.py b/gufe/tests/test_molhashing.py index cea50cee..33f3a7f7 100644 --- a/gufe/tests/test_molhashing.py +++ b/gufe/tests/test_molhashing.py @@ -1,9 +1,10 @@ -import pytest -from gufe.molhashing import serialize_numpy, deserialize_numpy import numpy as np +import pytest + +from gufe.molhashing import deserialize_numpy, serialize_numpy -@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int']) +@pytest.mark.parametrize("dtype", ["float32", "float64", "int"]) def test_numpy_serialization_cycle(dtype): arr = np.array([1, 2], dtype=dtype) ser = serialize_numpy(arr) diff --git a/gufe/tests/test_proteincomponent.py b/gufe/tests/test_proteincomponent.py index 63d5b406..3efce664 100644 --- a/gufe/tests/test_proteincomponent.py +++ b/gufe/tests/test_proteincomponent.py @@ -1,20 +1,19 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -import os -import io import copy +import io +import os + import pytest +from numpy.testing import assert_almost_equal +from openmm import unit +from openmm.app import pdbfile from rdkit import Chem from gufe import ProteinComponent -from .test_tokenization import GufeTokenizableTestsMixin - -from openmm.app import pdbfile -from openmm import unit -from numpy.testing import assert_almost_equal - from .conftest import ALL_PDB_LOADERS +from .test_tokenization import GufeTokenizableTestsMixin @pytest.fixture @@ -34,13 +33,13 @@ def assert_same_pdb_lines(in_file_path, out_file_path): if hasattr(in_file_path, "readlines"): in_file = in_file_path else: - in_file = open(in_file_path, "r") + in_file = open(in_file_path) if isinstance(out_file_path, io.StringIO): out_file = out_file_path must_close = False else: - out_file = open(out_file_path, mode='r') + out_file = open(out_file_path) must_close = True in_lines = in_file.readlines() @@ -49,10 +48,8 @@ def assert_same_pdb_lines(in_file_path, out_file_path): if must_close: out_file.close() - in_lines = [l for l in in_lines - if not l.startswith(('REMARK', 'CRYST', '# Created with'))] - out_lines = [l for l in out_lines - if not l.startswith(('REMARK', 'CRYST', '# Created with'))] + in_lines = [l for l in in_lines if not l.startswith(("REMARK", "CRYST", "# Created with"))] + out_lines = [l for l in out_lines if not l.startswith(("REMARK", "CRYST", "# Created with"))] assert in_lines == out_lines @@ -94,7 +91,7 @@ def instance(self, PDB_181L_path): return self.cls.from_pdb_file(PDB_181L_path, name="Steve") # From - @pytest.mark.parametrize('in_pdb_path', ALL_PDB_LOADERS.keys()) + @pytest.mark.parametrize("in_pdb_path", ALL_PDB_LOADERS.keys()) def test_from_pdb_file(self, in_pdb_path): in_pdb_io = ALL_PDB_LOADERS[in_pdb_path]() p = self.cls.from_pdb_file(in_pdb_io, name="Steve") @@ -125,8 +122,7 @@ def test_to_rdkit(self, PDB_181L_path): assert isinstance(rdkitmol, Chem.Mol) assert rdkitmol.GetNumAtoms() == 2639 - def _test_file_output(self, input_path, output_path, input_type, - output_func): + def _test_file_output(self, input_path, output_path, input_type, output_func): if input_type == "filename": inp = str(output_path) elif input_type == "Path": @@ -134,7 +130,7 @@ def _test_file_output(self, input_path, output_path, input_type, elif input_type == "StringIO": inp = io.StringIO() elif input_type == "TextIOWrapper": - inp = open(output_path, mode='w') + inp = open(output_path, mode="w") output_func(inp) @@ -145,13 +141,10 @@ def _test_file_output(self, input_path, output_path, input_type, if input_type == "TextIOWrapper": inp.close() - assert_same_pdb_lines(in_file_path=str(input_path), - out_file_path=output_path) + assert_same_pdb_lines(in_file_path=str(input_path), out_file_path=output_path) - @pytest.mark.parametrize('input_type', ['filename', 'Path', 'StringIO', - 'TextIOWrapper']) - def test_to_pdbx_file(self, PDBx_181L_openMMClean_path, tmp_path, - input_type): + @pytest.mark.parametrize("input_type", ["filename", "Path", "StringIO", "TextIOWrapper"]) + def test_to_pdbx_file(self, PDBx_181L_openMMClean_path, tmp_path, input_type): p = self.cls.from_pdbx_file(str(PDBx_181L_openMMClean_path), name="Bob") out_file_name = "tmp_181L_pdbx.cif" out_file = tmp_path / out_file_name @@ -160,28 +153,26 @@ def test_to_pdbx_file(self, PDBx_181L_openMMClean_path, tmp_path, input_path=PDBx_181L_openMMClean_path, output_path=out_file, input_type=input_type, - output_func=p.to_pdbx_file + output_func=p.to_pdbx_file, ) - @pytest.mark.parametrize('input_type', ['filename', 'Path', 'StringIO', - 'TextIOWrapper']) - def test_to_pdb_input_types(self, PDB_181L_OpenMMClean_path, tmp_path, - input_type): + @pytest.mark.parametrize("input_type", ["filename", "Path", "StringIO", "TextIOWrapper"]) + def test_to_pdb_input_types(self, PDB_181L_OpenMMClean_path, tmp_path, input_type): p = self.cls.from_pdb_file(str(PDB_181L_OpenMMClean_path), name="Bob") self._test_file_output( input_path=PDB_181L_OpenMMClean_path, output_path=tmp_path / "tmp_181L.pdb", input_type=input_type, - output_func=p.to_pdb_file + output_func=p.to_pdb_file, ) - @pytest.mark.parametrize('in_pdb_path', ALL_PDB_LOADERS.keys()) + @pytest.mark.parametrize("in_pdb_path", ALL_PDB_LOADERS.keys()) def test_to_pdb_round_trip(self, in_pdb_path, tmp_path): in_pdb_io = ALL_PDB_LOADERS[in_pdb_path]() p = self.cls.from_pdb_file(in_pdb_io, name="Wuff") - out_file_name = "tmp_"+in_pdb_path+".pdb" + out_file_name = "tmp_" + in_pdb_path + ".pdb" out_file = tmp_path / out_file_name p.to_pdb_file(str(out_file)) @@ -190,7 +181,7 @@ def test_to_pdb_round_trip(self, in_pdb_path, tmp_path): # generate openMM reference file: openmm_pdb = pdbfile.PDBFile(ref_in_pdb_io) - out_ref_file_name = "tmp_"+in_pdb_path+"_openmm_ref.pdb" + out_ref_file_name = "tmp_" + in_pdb_path + "_openmm_ref.pdb" out_ref_file = tmp_path / out_ref_file_name pdbfile.PDBFile.writeFile(openmm_pdb.topology, openmm_pdb.positions, file=open(str(out_ref_file), "w")) @@ -213,10 +204,10 @@ def test_dummy_from_dict(self, PDB_181L_OpenMMClean_path): assert p == p2 # parametrize - @pytest.mark.parametrize('in_pdb_path', ALL_PDB_LOADERS.keys()) + @pytest.mark.parametrize("in_pdb_path", ALL_PDB_LOADERS.keys()) def test_to_openmm_positions(self, in_pdb_path): in_pdb_io = ALL_PDB_LOADERS[in_pdb_path]() - ref_in_pdb_io = ALL_PDB_LOADERS[in_pdb_path]() + ref_in_pdb_io = ALL_PDB_LOADERS[in_pdb_path]() openmm_pdb = pdbfile.PDBFile(ref_in_pdb_io) openmm_pos = openmm_pdb.positions @@ -230,10 +221,10 @@ def test_to_openmm_positions(self, in_pdb_path): assert_almost_equal(actual=v1, desired=v2, decimal=6) # parametrize - @pytest.mark.parametrize('in_pdb_path', ALL_PDB_LOADERS.keys()) + @pytest.mark.parametrize("in_pdb_path", ALL_PDB_LOADERS.keys()) def test_to_openmm_topology(self, in_pdb_path): - in_pdb_io = ALL_PDB_LOADERS[in_pdb_path]() - ref_in_pdb_io = ALL_PDB_LOADERS[in_pdb_path]() + in_pdb_io = ALL_PDB_LOADERS[in_pdb_path]() + ref_in_pdb_io = ALL_PDB_LOADERS[in_pdb_path]() openmm_pdb = pdbfile.PDBFile(ref_in_pdb_io) openmm_top = openmm_pdb.topology diff --git a/gufe/tests/test_protocol.py b/gufe/tests/test_protocol.py index 4c44418f..50fb8c0f 100644 --- a/gufe/tests/test_protocol.py +++ b/gufe/tests/test_protocol.py @@ -2,29 +2,29 @@ # For details, see https://github.com/OpenFreeEnergy/gufe import datetime import itertools -from openff.units import unit -from typing import Optional, Iterable, List, Dict, Any, Union -from collections import defaultdict import pathlib +from collections import defaultdict +from collections.abc import Iterable, Sized +from typing import Any, Dict, List, Optional, Union -import pytest import networkx as nx import numpy as np +import pytest +from openff.units import unit import gufe +from gufe import settings from gufe.chemicalsystem import ChemicalSystem from gufe.mapping import ComponentMapping -from gufe import settings from gufe.protocols import ( Protocol, ProtocolDAG, - ProtocolUnit, - ProtocolResult, ProtocolDAGResult, - ProtocolUnitResult, + ProtocolResult, + ProtocolUnit, ProtocolUnitFailure, + ProtocolUnitResult, ) - from gufe.protocols.protocoldag import execute_DAG from .test_tokenization import GufeTokenizableTestsMixin @@ -41,15 +41,15 @@ def _execute(ctx, *, settings, stateA, stateB, mapping, start, **inputs): class SimulationUnit(ProtocolUnit): @staticmethod def _execute(ctx, *, initialization, **inputs): - output = [initialization.outputs['log']] + output = [initialization.outputs["log"]] output.append("running_md_{}".format(inputs["window"])) return dict( log=output, window=inputs["window"], - key_result=(100 - (inputs["window"] - 10)**2), + key_result=(100 - (inputs["window"] - 10) ** 2), scratch=ctx.scratch, - shared=ctx.shared + shared=ctx.shared, ) @@ -57,15 +57,12 @@ class FinishUnit(ProtocolUnit): @staticmethod def _execute(ctx, *, simulations, **inputs): - output = [s.outputs['log'] for s in simulations] + output = [s.outputs["log"] for s in simulations] output.append("assembling_results") - key_results = {str(s.inputs['window']): s.outputs['key_result'] for s in simulations} + key_results = {str(s.inputs["window"]): s.outputs["key_result"] for s in simulations} - return dict( - log=output, - key_results=key_results - ) + return dict(log=output, key_results=key_results) class DummySpecificSettings(settings.Settings): @@ -78,7 +75,7 @@ def get_estimate(self): # product of neighboring simulation window `key_result`s dgs = [] - for sample in self.data['key_results']: + for sample in self.data["key_results"]: windows = sorted(sample.keys()) dg = 0 for i, j in zip(windows[:-1], windows[1:]): @@ -88,8 +85,7 @@ def get_estimate(self): return np.mean(dg) - def get_uncertainty(self): - ... + def get_uncertainty(self): ... class DummyProtocol(Protocol): @@ -112,16 +108,16 @@ def _create( self, stateA: ChemicalSystem, stateB: ChemicalSystem, - mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]]=None, + mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]] = None, extends: Optional[ProtocolDAGResult] = None, - ) -> List[ProtocolUnit]: + ) -> list[ProtocolUnit]: # rip apart `extends` if needed to feed into `InitializeUnit` if extends is not None: # this is an example; wouldn't want to pass in whole ProtocolDAGResult into # any ProtocolUnits below, since this could create dependency hell; # instead, extract what's needed from it for starting point here - starting_point = extends.protocol_unit_results[-1].outputs['key_results'] + starting_point = extends.protocol_unit_results[-1].outputs["key_results"] else: starting_point = None @@ -133,10 +129,11 @@ def _create( stateB=stateB, mapping=mapping, start=starting_point, - some_dict={'a': 2, 'b': 12}) + some_dict={"a": 2, "b": 12}, + ) # create several units that would each run an independent simulation - simulations: List[ProtocolUnit] = [ + simulations: list[ProtocolUnit] = [ SimulationUnit(settings=self.settings, name=f"sim {i}", window=i, initialization=alpha) for i in range(self.settings.n_repeats) # type: ignore ] @@ -147,16 +144,14 @@ def _create( # return all `ProtocolUnit`s we created return [alpha, *simulations, omega] - def _gather( - self, protocol_dag_results: Iterable[ProtocolDAGResult] - ) -> Dict[str, Any]: + def _gather(self, protocol_dag_results: Iterable[ProtocolDAGResult]) -> dict[str, Any]: outputs = defaultdict(list) for pdr in protocol_dag_results: for pur in pdr.terminal_protocol_unit_results: if pur.name == "the end": - outputs['logs'].append(pur.outputs['log']) - outputs['key_results'].append(pur.outputs['key_results']) + outputs["logs"].append(pur.outputs["log"]) + outputs["key_results"].append(pur.outputs["key_results"]) return dict(outputs) @@ -164,7 +159,7 @@ def _gather( class BrokenSimulationUnit(SimulationUnit): @staticmethod def _execute(ctx, **inputs): - raise ValueError("I have failed my mission", {'data': 'lol'}) + raise ValueError("I have failed my mission", {"data": "lol"}) class BrokenProtocol(DummyProtocol): @@ -172,7 +167,7 @@ def _create( self, stateA: ChemicalSystem, stateB: ChemicalSystem, - mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]]=None, + mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]] = None, extends: Optional[ProtocolDAGResult] = None, ) -> list[ProtocolUnit]: @@ -186,12 +181,19 @@ def _create( ) # create several units that would each run an independent simulation - simulations: List[ProtocolUnit] = [ + simulations: list[ProtocolUnit] = [ SimulationUnit(settings=self.settings, name=f"sim {i}", window=i, initialization=alpha) for i in range(21) ] # introduce a broken ProtocolUnit - simulations.append(BrokenSimulationUnit(settings=self.settings, window=21, name="problem child", initialization=alpha)) + simulations.append( + BrokenSimulationUnit( + settings=self.settings, + window=21, + name="problem child", + initialization=alpha, + ) + ) # gather results from simulations, finalize outputs omega = FinishUnit(settings=self.settings, name="the end", simulations=simulations) @@ -213,18 +215,19 @@ def instance(self): def protocol_dag(self, solvated_ligand, vacuum_ligand, tmpdir): protocol = DummyProtocol(settings=DummyProtocol.default_settings()) dag = protocol.create( - stateA=solvated_ligand, stateB=vacuum_ligand, name="a dummy run", + stateA=solvated_ligand, + stateB=vacuum_ligand, + name="a dummy run", mapping=None, ) with tmpdir.as_cwd(): - shared = pathlib.Path('shared') + shared = pathlib.Path("shared") shared.mkdir(parents=True) - - scratch = pathlib.Path('scratch') + + scratch = pathlib.Path("scratch") scratch.mkdir(parents=True) - dagresult: ProtocolDAGResult = execute_DAG( - dag, shared_basedir=shared, scratch_basedir=scratch) + dagresult: ProtocolDAGResult = execute_DAG(dag, shared_basedir=shared, scratch_basedir=scratch) return protocol, dag, dagresult @@ -232,18 +235,21 @@ def protocol_dag(self, solvated_ligand, vacuum_ligand, tmpdir): def protocol_dag_broken(self, solvated_ligand, vacuum_ligand, tmpdir): protocol = BrokenProtocol(settings=BrokenProtocol.default_settings()) dag = protocol.create( - stateA=solvated_ligand, stateB=vacuum_ligand, name="a broken dummy run", + stateA=solvated_ligand, + stateB=vacuum_ligand, + name="a broken dummy run", mapping=None, ) with tmpdir.as_cwd(): - shared = pathlib.Path('shared') + shared = pathlib.Path("shared") shared.mkdir(parents=True) - - scratch = pathlib.Path('scratch') + + scratch = pathlib.Path("scratch") scratch.mkdir(parents=True) dagfailure: ProtocolDAGResult = execute_DAG( - dag, shared_basedir=shared, scratch_basedir=scratch, raise_error=False) + dag, shared_basedir=shared, scratch_basedir=scratch, raise_error=False + ) return protocol, dag, dagfailure @@ -257,24 +263,24 @@ def test_dag_execute(self, protocol_dag): assert finishresult.name == "the end" # gather SimulationUnits - simulationresults = [dagresult.unit_to_result(pu) - for pu in dagresult.protocol_units - if isinstance(pu, SimulationUnit)] + simulationresults = [ + dagresult.unit_to_result(pu) for pu in dagresult.protocol_units if isinstance(pu, SimulationUnit) + ] # check that we have dependency information in results - assert set(finishresult.inputs['simulations']) == {u for u in simulationresults} + assert set(finishresult.inputs["simulations"]) == {u for u in simulationresults} # check that we have as many units as we expect in resulting graph assert len(dagresult.graph) == 23 - + # check that each simulation has its own shared directory - assert len(set(i.outputs['shared'] for i in simulationresults)) == len(simulationresults) + assert len({i.outputs["shared"] for i in simulationresults}) == len(simulationresults) # check that each simulation has its own scratch directory - assert len(set(i.outputs['scratch'] for i in simulationresults)) == len(simulationresults) + assert len({i.outputs["scratch"] for i in simulationresults}) == len(simulationresults) # check that shared and scratch not the same for each simulation - assert all([i.outputs['scratch'] != i.outputs['shared'] for i in simulationresults]) + assert all([i.outputs["scratch"] != i.outputs["shared"] for i in simulationresults]) def test_terminal_units(self, protocol_dag): prot, dag, res = protocol_dag @@ -283,7 +289,7 @@ def test_terminal_units(self, protocol_dag): assert len(finals) == 1 assert isinstance(finals[0], ProtocolUnitResult) - assert finals[0].name == 'the end' + assert finals[0].name == "the end" def test_dag_execute_failure(self, protocol_dag_broken): protocol, dag, dagfailure = protocol_dag_broken @@ -297,7 +303,7 @@ def test_dag_execute_failure(self, protocol_dag_broken): assert failed_units[0].name == "problem child" # parse exception arguments - assert failed_units[0].exception[1][1]['data'] == "lol" + assert failed_units[0].exception[1][1]["data"] == "lol" assert isinstance(failed_units[0], ProtocolUnitFailure) succeeded_units = dagfailure.protocol_unit_results @@ -307,18 +313,25 @@ def test_dag_execute_failure(self, protocol_dag_broken): def test_dag_execute_failure_raise_error(self, solvated_ligand, vacuum_ligand, tmpdir): protocol = BrokenProtocol(settings=BrokenProtocol.default_settings()) dag = protocol.create( - stateA=solvated_ligand, stateB=vacuum_ligand, name="a broken dummy run", + stateA=solvated_ligand, + stateB=vacuum_ligand, + name="a broken dummy run", mapping=None, ) with tmpdir.as_cwd(): - shared = pathlib.Path('shared') + shared = pathlib.Path("shared") shared.mkdir(parents=True) - - scratch = pathlib.Path('scratch') + + scratch = pathlib.Path("scratch") scratch.mkdir(parents=True) with pytest.raises(ValueError, match="I have failed my mission"): - execute_DAG(dag, shared_basedir=shared, scratch_basedir=scratch, raise_error=True) + execute_DAG( + dag, + shared_basedir=shared, + scratch_basedir=scratch, + raise_error=True, + ) def test_create_execute_gather(self, protocol_dag): protocol, dag, dagresult = protocol_dag @@ -328,28 +341,51 @@ def test_create_execute_gather(self, protocol_dag): # gather aggregated results of interest protocolresult = protocol.gather([dagresult]) - assert len(protocolresult.data['logs']) == 1 - assert len(protocolresult.data['logs'][0]) == 21 + 1 + assert protocolresult.n_protocol_dag_results == 1 + assert len(protocolresult.data["logs"]) == 1 + assert len(protocolresult.data["logs"][0]) == 21 + 1 assert protocolresult.get_estimate() == 95500.0 + def test_gather_infinite_iterable_guardrail(self, protocol_dag): + protocol, dag, dagresult = protocol_dag + + assert dagresult.ok() + + # we want an infinite generator, but one that would actually stop early in case + # the guardrail doesn't work, but the type system doesn't know that + def infinite_generator(): + while True: + yield dag + break + + gen = infinite_generator() + assert isinstance(gen, Iterable) + assert not isinstance(gen, Sized) + + with pytest.raises(ValueError, match="`protocol_dag_results` must implement `__len__`"): + protocol.gather(infinite_generator()) + def test_deprecation_warning_on_dict_mapping(self, instance, vacuum_ligand, solvated_ligand): - lig = solvated_ligand.components['ligand'] + lig = solvated_ligand.components["ligand"] + mapping = gufe.LigandAtomMapping(lig, lig, componentA_to_componentB={}) - with pytest.warns(DeprecationWarning, - match="mapping input as a dict is deprecated"): - instance.create(stateA=solvated_ligand, stateB=vacuum_ligand, - mapping={'ligand': mapping}) + with pytest.warns(DeprecationWarning, match="mapping input as a dict is deprecated"): + instance.create( + stateA=solvated_ligand, + stateB=vacuum_ligand, + mapping={"ligand": mapping}, + ) class ProtocolDAGTestsMixin(GufeTokenizableTestsMixin): - + def test_protocol_units(self, instance): # ensure that protocol units are given in-order based on DAG # dependencies checked = [] for pu in instance.protocol_units: - assert set(pu.dependencies).issubset(checked) + assert set(pu.dependencies).issubset(checked) checked.append(pu) def test_graph(self, instance): @@ -394,7 +430,7 @@ def instance(self, protocol_dag): def test_protocol_unit_results(self, instance: ProtocolDAGResult): # ensure that protocolunitresults are given in-order based on DAG # dependencies - checked: List[Union[ProtocolUnitResult, ProtocolUnitFailure]] = [] + checked: list[Union[ProtocolUnitResult, ProtocolUnitFailure]] = [] for pur in instance.protocol_unit_results: assert set(pur.dependencies).issubset(checked) checked.append(pur) @@ -446,8 +482,7 @@ def test_protocol_unit_failures(self, instance: ProtocolDAGResult): # protocolunitfailures should have no dependents for puf in instance.protocol_unit_failures: - assert all([puf not in pu.dependencies - for pu in instance.protocol_unit_results]) + assert all([puf not in pu.dependencies for pu in instance.protocol_unit_results]) for node in instance.result_graph.nodes: with pytest.raises(KeyError): @@ -464,10 +499,10 @@ def test_protocol_unit_failure_traceback(self, instance: ProtocolDAGResult): class TestProtocolUnit(GufeTokenizableTestsMixin): cls = SimulationUnit repr = None - + @pytest.fixture def instance(self, vacuum_ligand, solvated_ligand): - + # convert protocol inputs into starting points for independent simulations alpha = InitializeUnit( name="the beginning", @@ -476,9 +511,9 @@ def instance(self, vacuum_ligand, solvated_ligand): stateB=solvated_ligand, mapping=None, start=None, - some_dict={'a': 2, 'b': 12}, + some_dict={"a": 2, "b": 12}, ) - + return SimulationUnit(name=f"simulation", initialization=alpha) def test_key_stable(self, instance): @@ -489,20 +524,21 @@ def test_key_stable(self, instance): class NoDepUnit(ProtocolUnit): @staticmethod - def _execute(ctx, **inputs) -> Dict[str, Any]: - return {'local': inputs['val'] ** 2} + def _execute(ctx, **inputs) -> dict[str, Any]: + return {"local": inputs["val"] ** 2} class NoDepResults(ProtocolResult): def get_estimate(self): - return sum(self.data['vals']) + return sum(self.data["vals"]) def get_uncertainty(self): - return len(self.data['vals']) + return len(self.data["vals"]) class NoDepsProtocol(Protocol): """A protocol without dependencies""" + result_cls = NoDepResults @classmethod @@ -514,20 +550,21 @@ def _default_settings(cls): return settings.Settings.get_defaults() def _create( - self, - stateA: ChemicalSystem, - stateB: ChemicalSystem, - mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]] = None, - extends: Optional[ProtocolDAGResult] = None, - ) -> List[ProtocolUnit]: - return [NoDepUnit(settings=self.settings, - val=i) - for i in range(3)] + self, + stateA: ChemicalSystem, + stateB: ChemicalSystem, + mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]] = None, + extends: Optional[ProtocolDAGResult] = None, + ) -> list[ProtocolUnit]: + return [NoDepUnit(settings=self.settings, val=i) for i in range(3)] def _gather(self, dag_results): return { - 'vals': list(itertools.chain.from_iterable( - (d.outputs['local'] for d in dag.protocol_unit_results) for dag in dag_results)), + "vals": list( + itertools.chain.from_iterable( + (d.outputs["local"] for d in dag.protocol_unit_results) for dag in dag_results + ) + ), } @@ -539,19 +576,20 @@ def protocol(self): @pytest.fixture() def dag(self, protocol): return protocol.create( - stateA=ChemicalSystem(components={'solvent': gufe.SolventComponent(positive_ion='Na')}), - stateB=ChemicalSystem(components={'solvent': gufe.SolventComponent(positive_ion='Li')}), - mapping=None) + stateA=ChemicalSystem(components={"solvent": gufe.SolventComponent(positive_ion="Na")}), + stateB=ChemicalSystem(components={"solvent": gufe.SolventComponent(positive_ion="Li")}), + mapping=None, + ) def test_create(self, dag): assert len(dag.protocol_units) == 3 def test_gather(self, protocol, dag, tmpdir): with tmpdir.as_cwd(): - shared = pathlib.Path('shared') + shared = pathlib.Path("shared") shared.mkdir(parents=True) - - scratch = pathlib.Path('scratch') + + scratch = pathlib.Path("scratch") scratch.mkdir(parents=True) dag_result = execute_DAG(dag, shared_basedir=shared, scratch_basedir=scratch) @@ -565,10 +603,10 @@ def test_gather(self, protocol, dag, tmpdir): def test_terminal_units(self, protocol, dag, tmpdir): with tmpdir.as_cwd(): - shared = pathlib.Path('shared') + shared = pathlib.Path("shared") shared.mkdir(parents=True) - - scratch = pathlib.Path('scratch') + + scratch = pathlib.Path("scratch") scratch.mkdir(parents=True) # we have no dependencies, so this should be all three Unit results @@ -581,6 +619,7 @@ def test_terminal_units(self, protocol, dag, tmpdir): class TestProtocolDAGResult: """tests for combinations of failures and successes in a DAGResult""" + @staticmethod @pytest.fixture() def units() -> list[ProtocolUnit]: @@ -593,10 +632,16 @@ def successes(units) -> list[ProtocolUnitResult]: t1 = datetime.datetime.now() t2 = datetime.datetime.now() - return [ProtocolUnitResult(source_key=u.key, inputs=u.inputs, - outputs={'result': i ** 2}, - start_time=t1, end_time=t2) - for i, u in enumerate(units)] + return [ + ProtocolUnitResult( + source_key=u.key, + inputs=u.inputs, + outputs={"result": i**2}, + start_time=t1, + end_time=t2, + ) + for i, u in enumerate(units) + ] @staticmethod @pytest.fixture() @@ -605,13 +650,21 @@ def failures(units) -> list[list[ProtocolUnitFailure]]: t1 = datetime.datetime.now() t2 = datetime.datetime.now() - return [[ProtocolUnitFailure(source_key=u.key, inputs=u.inputs, - outputs=dict(), - exception=('ValueError', "Didn't feel like it"), - traceback='foo', - start_time=t1, end_time=t2) - for i in range(2)] - for u in units] + return [ + [ + ProtocolUnitFailure( + source_key=u.key, + inputs=u.inputs, + outputs=dict(), + exception=("ValueError", "Didn't feel like it"), + traceback="foo", + start_time=t1, + end_time=t2, + ) + for i in range(2) + ] + for u in units + ] def test_all_successes(self, units, successes): dagresult = ProtocolDAGResult( @@ -679,22 +732,26 @@ def test_foreign_objects(self, units, successes): def test_execute_DAG_retries(solvated_ligand, vacuum_ligand, tmpdir): protocol = BrokenProtocol(settings=BrokenProtocol.default_settings()) dag = protocol.create( - stateA=solvated_ligand, stateB=vacuum_ligand, mapping=None, + stateA=solvated_ligand, + stateB=vacuum_ligand, + mapping=None, ) with tmpdir.as_cwd(): - shared = pathlib.Path('shared') + shared = pathlib.Path("shared") shared.mkdir(parents=True) - scratch = pathlib.Path('scratch') + scratch = pathlib.Path("scratch") scratch.mkdir(parents=True) - r = execute_DAG(dag, - shared_basedir=shared, - scratch_basedir=scratch, - keep_shared=True, - keep_scratch=True, - raise_error=False, - n_retries=3) + r = execute_DAG( + dag, + shared_basedir=shared, + scratch_basedir=scratch, + keep_shared=True, + keep_scratch=True, + raise_error=False, + n_retries=3, + ) assert not r.ok() @@ -708,26 +765,31 @@ def test_execute_DAG_retries(solvated_ligand, vacuum_ligand, tmpdir): # final failure assert number_unit_results == number_dirs == 26 + def test_execute_DAG_bad_nretries(solvated_ligand, vacuum_ligand, tmpdir): protocol = BrokenProtocol(settings=BrokenProtocol.default_settings()) dag = protocol.create( - stateA=solvated_ligand, stateB=vacuum_ligand, mapping=None, + stateA=solvated_ligand, + stateB=vacuum_ligand, + mapping=None, ) with tmpdir.as_cwd(): - shared = pathlib.Path('shared') + shared = pathlib.Path("shared") shared.mkdir(parents=True) - scratch = pathlib.Path('scratch') + scratch = pathlib.Path("scratch") scratch.mkdir(parents=True) with pytest.raises(ValueError): - r = execute_DAG(dag, - shared_basedir=shared, - scratch_basedir=scratch, - keep_shared=True, - keep_scratch=True, - raise_error=False, - n_retries=-1) + r = execute_DAG( + dag, + shared_basedir=shared, + scratch_basedir=scratch, + keep_shared=True, + keep_scratch=True, + raise_error=False, + n_retries=-1, + ) def test_settings_readonly(): diff --git a/gufe/tests/test_protocoldag.py b/gufe/tests/test_protocoldag.py index 9454eda0..5c898492 100644 --- a/gufe/tests/test_protocoldag.py +++ b/gufe/tests/test_protocoldag.py @@ -2,6 +2,7 @@ # For details, see https://github.com/OpenFreeEnergy/gufe import os import pathlib + import pytest from openff.units import unit @@ -12,15 +13,15 @@ class WriterUnit(gufe.ProtocolUnit): @staticmethod def _execute(ctx, **inputs): - my_id = inputs['identity'] + my_id = inputs["identity"] - with open(os.path.join(ctx.shared, f'unit_{my_id}_shared.txt'), 'w') as out: - out.write(f'unit {my_id} existed!\n') - with open(os.path.join(ctx.scratch, f'unit_{my_id}_scratch.txt'), 'w') as out: - out.write(f'unit {my_id} was here\n') + with open(os.path.join(ctx.shared, f"unit_{my_id}_shared.txt"), "w") as out: + out.write(f"unit {my_id} existed!\n") + with open(os.path.join(ctx.scratch, f"unit_{my_id}_scratch.txt"), "w") as out: + out.write(f"unit {my_id} was here\n") return { - 'log': 'finished', + "log": "finished", } @@ -30,11 +31,9 @@ class WriterSettings(gufe.settings.Settings): class WriterProtocolResult(gufe.ProtocolResult): - def get_estimate(self): - ... + def get_estimate(self): ... - def get_uncertainty(self): - ... + def get_uncertainty(self): ... class WriterProtocol(gufe.Protocol): @@ -45,18 +44,15 @@ def _default_settings(cls): return WriterSettings( thermo_settings=gufe.settings.ThermoSettings(temperature=298 * unit.kelvin), forcefield_settings=gufe.settings.OpenMMSystemGeneratorFFSettings(), - n_repeats=4 + n_repeats=4, ) - @classmethod def _defaults(cls): return {} - def _create(self, stateA, stateB, mapping, extends=None) -> list[gufe.ProtocolUnit]: - return [ - WriterUnit(identity=i) for i in range(self.settings.n_repeats) # type: ignore - ] + def _create(self, stateA, stateB, mapping, extends=None) -> list[gufe.ProtocolUnit]: + return [WriterUnit(identity=i) for i in range(self.settings.n_repeats)] # type: ignore def _gather(self, results): return {} @@ -72,34 +68,36 @@ def writefile_dag(): return p.create(stateA=s1, stateB=s2, mapping=[]) -@pytest.mark.parametrize('keep_shared', [False, True]) -@pytest.mark.parametrize('keep_scratch', [False, True]) +@pytest.mark.parametrize("keep_shared", [False, True]) +@pytest.mark.parametrize("keep_scratch", [False, True]) def test_execute_dag(tmpdir, keep_shared, keep_scratch, writefile_dag): with tmpdir.as_cwd(): - shared = pathlib.Path('shared') + shared = pathlib.Path("shared") shared.mkdir(parents=True) - scratch = pathlib.Path('scratch') + scratch = pathlib.Path("scratch") scratch.mkdir(parents=True) - + # run dag - execute_DAG(writefile_dag, - shared_basedir=shared, - scratch_basedir=scratch, - keep_shared=keep_shared, - keep_scratch=keep_scratch) - + execute_DAG( + writefile_dag, + shared_basedir=shared, + scratch_basedir=scratch, + keep_shared=keep_shared, + keep_scratch=keep_scratch, + ) + # check outputs are as expected # will have produced 4 files in scratch and shared directory for pu in writefile_dag.protocol_units: - identity = pu.inputs['identity'] - shared_file = os.path.join(shared, - f'shared_{str(pu.key)}_attempt_0', - f'unit_{identity}_shared.txt') - scratch_file = os.path.join(scratch, - f'scratch_{str(pu.key)}_attempt_0', - f'unit_{identity}_scratch.txt') + identity = pu.inputs["identity"] + shared_file = os.path.join(shared, f"shared_{str(pu.key)}_attempt_0", f"unit_{identity}_shared.txt") + scratch_file = os.path.join( + scratch, + f"scratch_{str(pu.key)}_attempt_0", + f"unit_{identity}_scratch.txt", + ) if keep_shared: assert os.path.exists(shared_file) else: diff --git a/gufe/tests/test_protocolresult.py b/gufe/tests/test_protocolresult.py index 4746906b..86ee0c1c 100644 --- a/gufe/tests/test_protocolresult.py +++ b/gufe/tests/test_protocolresult.py @@ -1,18 +1,19 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/gufe -from openff.units import unit import pytest +from openff.units import unit import gufe + from .test_tokenization import GufeTokenizableTestsMixin class DummyProtocolResult(gufe.ProtocolResult): def get_estimate(self): - return self.data['estimate'] + return self.data["estimate"] def get_uncertainty(self): - return self.data['uncertainty'] + return self.data["uncertainty"] class TestProtocolResult(GufeTokenizableTestsMixin): @@ -25,3 +26,31 @@ def instance(self): estimate=4.2 * unit.kilojoule_per_mole, uncertainty=0.2 * unit.kilojoule_per_mole, ) + + def test_protocolresult_get_estimate(self, instance): + assert instance.get_estimate() == 4.2 * unit.kilojoule_per_mole + + def test_protocolresult_get_uncertainty(self, instance): + assert instance.get_uncertainty() == 0.2 * unit.kilojoule_per_mole + + def test_protocolresult_default_n_protocol_dag_results(self, instance): + assert instance.n_protocol_dag_results == 0 + + def test_protocol_result_from_dict_missing_n_protocol_dag_results(self, instance): + protocol_result_dict_form = instance.to_dict() + assert DummyProtocolResult.from_dict(protocol_result_dict_form) == instance + del protocol_result_dict_form["n_protocol_dag_results"] + assert DummyProtocolResult.from_dict(protocol_result_dict_form) == instance + + @pytest.mark.parametrize("arg, expected", [(0, 0), (1, 1), (-1, ValueError)]) + def test_protocolresult_get_n_protocol_dag_results_args(self, arg, expected): + try: + protocol_result = DummyProtocolResult( + n_protocol_dag_results=arg, + estimate=4.2 * unit.kilojoule_per_mole, + uncertainty=0.2 * unit.kilojoule_per_mole, + ) + assert protocol_result.n_protocol_dag_results == expected + except ValueError: + if expected is not ValueError: + raise AssertionError() diff --git a/gufe/tests/test_protocolunit.py b/gufe/tests/test_protocolunit.py index 9896f856..51a37a47 100644 --- a/gufe/tests/test_protocolunit.py +++ b/gufe/tests/test_protocolunit.py @@ -1,8 +1,9 @@ import string -import pytest from pathlib import Path -from gufe.protocols.protocolunit import ProtocolUnit, Context, ProtocolUnitFailure, ProtocolUnitResult +import pytest + +from gufe.protocols.protocolunit import Context, ProtocolUnit, ProtocolUnitFailure, ProtocolUnitResult from gufe.tests.test_tokenization import GufeTokenizableTestsMixin @@ -51,27 +52,27 @@ def test_execute(self, tmpdir): unit = DummyUnit() - shared = Path('shared') / str(unit.key) + shared = Path("shared") / str(unit.key) shared.mkdir(parents=True) - scratch = Path('scratch') / str(unit.key) + scratch = Path("scratch") / str(unit.key) scratch.mkdir(parents=True) ctx = Context(shared=shared, scratch=scratch) - + u: ProtocolUnitFailure = unit.execute(context=ctx, an_input=3) assert u.exception[0] == "ValueError" unit = DummyUnit() - shared = Path('shared') / str(unit.key) + shared = Path("shared") / str(unit.key) shared.mkdir(parents=True) - scratch = Path('scratch') / str(unit.key) + scratch = Path("scratch") / str(unit.key) scratch.mkdir(parents=True) ctx = Context(shared=shared, scratch=scratch) - + # now try actually letting the error raise on execute with pytest.raises(ValueError, match="should always be 2"): unit.execute(context=ctx, raise_error=True, an_input=3) @@ -81,23 +82,23 @@ def test_execute_KeyboardInterrupt(self, tmpdir): unit = DummyKeyboardInterruptUnit() - shared = Path('shared') / str(unit.key) + shared = Path("shared") / str(unit.key) shared.mkdir(parents=True) - scratch = Path('scratch') / str(unit.key) + scratch = Path("scratch") / str(unit.key) scratch.mkdir(parents=True) ctx = Context(shared=shared, scratch=scratch) - + with pytest.raises(KeyboardInterrupt): unit.execute(context=ctx, an_input=3) - + u: ProtocolUnitResult = unit.execute(context=ctx, an_input=2) - assert u.outputs == {'foo': 'bar'} + assert u.outputs == {"foo": "bar"} def test_normalize(self, dummy_unit): thingy = dummy_unit.key - assert thingy.startswith('DummyUnit-') - assert all(t in string.hexdigits for t in thingy.partition('-')[-1]) + assert thingy.startswith("DummyUnit-") + assert all(t in string.hexdigits for t in thingy.partition("-")[-1]) diff --git a/gufe/tests/test_serialization_migration.py b/gufe/tests/test_serialization_migration.py index 3e664252..69d71d4f 100644 --- a/gufe/tests/test_serialization_migration.py +++ b/gufe/tests/test_serialization_migration.py @@ -1,47 +1,51 @@ -import pytest import copy +from typing import Any, Optional, Type +import pytest +from pydantic import BaseModel + +from gufe.tests.test_tokenization import GufeTokenizableTestsMixin from gufe.tokenization import ( GufeTokenizable, - new_key_added, - old_key_removed, - key_renamed, - nested_key_moved, - from_dict, _label_to_parts, _pop_nested, _set_nested, + from_dict, + key_renamed, + nested_key_moved, + new_key_added, + old_key_removed, ) -from gufe.tests.test_tokenization import GufeTokenizableTestsMixin -from pydantic import BaseModel - -from typing import Optional, Any, Type - @pytest.fixture def nested_data(): - return { - "foo": {"foo2" : [{"foo3": "foo4"}, "foo5"]}, - "bar": ["bar2", "bar3"] - } - -@pytest.mark.parametrize('label, expected', [ - ("foo", ["foo"]), - ("foo.foo2", ["foo", "foo2"]), - ("foo.foo2[0]", ["foo", "foo2", 0]), - ("foo.foo2[0].foo3", ["foo", "foo2", 0, "foo3"]), -]) + return {"foo": {"foo2": [{"foo3": "foo4"}, "foo5"]}, "bar": ["bar2", "bar3"]} + + +@pytest.mark.parametrize( + "label, expected", + [ + ("foo", ["foo"]), + ("foo.foo2", ["foo", "foo2"]), + ("foo.foo2[0]", ["foo", "foo2", 0]), + ("foo.foo2[0].foo3", ["foo", "foo2", 0, "foo3"]), + ], +) def test_label_to_parts(label, expected): assert _label_to_parts(label) == expected -@pytest.mark.parametrize('label, popped, remaining', [ - ("foo", {"foo2" : [{"foo3": "foo4"}, "foo5"]}, {}), - ("foo.foo2", [{"foo3": "foo4"}, "foo5"], {"foo": {}}), - ("foo.foo2[0]", {"foo3": "foo4"}, {"foo": {"foo2": ["foo5"]}}), - ("foo.foo2[0].foo3", "foo4", {"foo": {"foo2": [{}, "foo5"]}}), - ("foo.foo2[1]", "foo5", {"foo": {"foo2": [{"foo3": "foo4"}]}}), -]) + +@pytest.mark.parametrize( + "label, popped, remaining", + [ + ("foo", {"foo2": [{"foo3": "foo4"}, "foo5"]}, {}), + ("foo.foo2", [{"foo3": "foo4"}, "foo5"], {"foo": {}}), + ("foo.foo2[0]", {"foo3": "foo4"}, {"foo": {"foo2": ["foo5"]}}), + ("foo.foo2[0].foo3", "foo4", {"foo": {"foo2": [{}, "foo5"]}}), + ("foo.foo2[1]", "foo5", {"foo": {"foo2": [{"foo3": "foo4"}]}}), + ], +) def test_pop_nested(nested_data, label, popped, remaining): val = _pop_nested(nested_data, label) expected_remaining = {"bar": ["bar2", "bar3"]} @@ -49,13 +53,17 @@ def test_pop_nested(nested_data, label, popped, remaining): assert val == popped assert nested_data == expected_remaining -@pytest.mark.parametrize("label, expected_foo", [ - ("foo", {"foo": 10}), - ("foo.foo2", {"foo": {"foo2": 10}}), - ("foo.foo2[0]", {"foo": {"foo2": [10, "foo5"]}}), - ("foo.foo2[0].foo3", {"foo": {"foo2": [{"foo3": 10}, "foo5"]}}), - ("foo.foo2[1]", {"foo": {"foo2": [{"foo3": "foo4"}, 10]}}), -]) + +@pytest.mark.parametrize( + "label, expected_foo", + [ + ("foo", {"foo": 10}), + ("foo.foo2", {"foo": {"foo2": 10}}), + ("foo.foo2[0]", {"foo": {"foo2": [10, "foo5"]}}), + ("foo.foo2[0].foo3", {"foo": {"foo2": [{"foo3": 10}, "foo5"]}}), + ("foo.foo2[1]", {"foo": {"foo2": [{"foo3": "foo4"}, 10]}}), + ], +) def test_set_nested(nested_data, label, expected_foo): _set_nested(nested_data, label, 10) expected = {"bar": ["bar2", "bar3"]} @@ -65,6 +73,7 @@ def test_set_nested(nested_data, label, expected_foo): class _DefaultBase(GufeTokenizable): """Convenience class to avoid rewriting these methods""" + @classmethod def _from_dict(cls, dct): return cls(**dct) @@ -80,16 +89,17 @@ def _schema_version(cls): # this represents an "original" object with fields `foo` and `bar` _SERIALIZED_OLD = { - '__module__': None, # define in each test - '__qualname__': None, # define in each test - 'foo': "foo", - 'bar': "bar", - ':version:': 1, + "__module__": None, # define in each test + "__qualname__": None, # define in each test + "foo": "foo", + "bar": "bar", + ":version:": 1, } class KeyAdded(_DefaultBase): """Add key ``qux`` to the object's dict""" + def __init__(self, foo, bar, qux=10): self.foo = foo self.bar = bar @@ -98,7 +108,7 @@ def __init__(self, foo, bar, qux=10): @classmethod def serialization_migration(cls, dct, version): if version == 1: - dct = new_key_added(dct, 'qux', 10) + dct = new_key_added(dct, "qux", 10) return dct @@ -108,6 +118,7 @@ def _to_dict(self): class KeyRemoved(_DefaultBase): """Remove key ``bar`` from the object's dict""" + def __init__(self, foo): self.foo = foo @@ -124,6 +135,7 @@ def _to_dict(self): class KeyRenamed(_DefaultBase): """Rename key ``bar`` to ``baz`` in the object's dict""" + def __init__(self, foo, baz): self.foo = foo self.baz = baz @@ -153,17 +165,17 @@ def instance(self): def _prep_dct(self, dct): dct = copy.deepcopy(self.input_dict) - dct['__module__'] = self.cls.__module__ - dct['__qualname__'] = self.cls.__qualname__ + dct["__module__"] = self.cls.__module__ + dct["__qualname__"] = self.cls.__qualname__ return dct def test_serialization_migration(self): # in these examples, self.kwargs is the same as the output of # serialization_migration (not necessarily true for all classes) dct = self._prep_dct(self.input_dict) - del dct['__module__'] - del dct['__qualname__'] - version = dct.pop(':version:') + del dct["__module__"] + del dct["__qualname__"] + version = dct.pop(":version:") assert self.cls.serialization_migration(dct, version) == self.kwargs def test_migration(self, instance): @@ -172,6 +184,7 @@ def test_migration(self, instance): expected = instance assert expected == reconstructed + class TestKeyAdded(MigrationTester): cls = KeyAdded input_dict = _SERIALIZED_OLD @@ -196,12 +209,7 @@ class TestKeyRenamed(MigrationTester): "__module__": ..., "__qualname__": ..., ":version:": 1, - "settings": { - "son": { - "son_child": 10 - }, - "daughter": {} - } + "settings": {"son": {"son_child": 10}, "daughter": {}}, } @@ -211,6 +219,7 @@ class SonSettings(BaseModel): class DaughterSettings(BaseModel): """v2 model has child; v1 would not""" + daughter_child: int @@ -224,11 +233,11 @@ def __init__(self, settings: GrandparentSettings): self.settings = settings def _to_dict(self): - return {'settings': self.settings.dict()} + return {"settings": self.settings.dict()} @classmethod def _from_dict(cls, dct): - settings = GrandparentSettings.parse_obj(dct['settings']) + settings = GrandparentSettings.parse_obj(dct["settings"]) return cls(settings=settings) @classmethod @@ -241,7 +250,7 @@ def serialization_migration(cls, dct, version): dct = nested_key_moved( dct, old_name="settings.son.son_child", - new_name="settings.daughter.daughter_child" + new_name="settings.daughter.daughter_child", ) return dct @@ -250,13 +259,8 @@ def serialization_migration(cls, dct, version): class TestNestedKeyMoved(MigrationTester): cls = Grandparent input_dict = _SERIALIZED_NESTED_OLD - kwargs = { - 'settings': {'son': {}, 'daughter': {'daughter_child': 10}} - } + kwargs = {"settings": {"son": {}, "daughter": {"daughter_child": 10}}} @pytest.fixture def instance(self): - return self.cls(GrandparentSettings( - son=SonSettings(), - daughter=DaughterSettings(daughter_child=10) - )) + return self.cls(GrandparentSettings(son=SonSettings(), daughter=DaughterSettings(daughter_child=10))) diff --git a/gufe/tests/test_smallmoleculecomponent.py b/gufe/tests/test_smallmoleculecomponent.py index 354771c9..e224227d 100644 --- a/gufe/tests/test_smallmoleculecomponent.py +++ b/gufe/tests/test_smallmoleculecomponent.py @@ -3,6 +3,7 @@ import importlib import importlib.resources + try: import openff.toolkit.topology from openff.units import unit @@ -10,54 +11,58 @@ HAS_OFFTK = False else: HAS_OFFTK = True +import json import os from unittest import mock -import pytest -from gufe import SmallMoleculeComponent -from gufe.components.explicitmoleculecomponent import ( - _ensure_ofe_name, -) -import gufe -import json +import pytest from rdkit import Chem from rdkit.Chem import AllChem + +import gufe +from gufe import SmallMoleculeComponent +from gufe.components.explicitmoleculecomponent import _ensure_ofe_name from gufe.tokenization import TOKENIZABLE_REGISTRY from .test_tokenization import GufeTokenizableTestsMixin + @pytest.fixture def alt_ethane(): mol = Chem.AddHs(Chem.MolFromSmiles("CC")) Chem.AllChem.Compute2DCoords(mol) return SmallMoleculeComponent(mol) + @pytest.fixture def named_ethane(): mol = Chem.AddHs(Chem.MolFromSmiles("CC")) Chem.AllChem.Compute2DCoords(mol) - return SmallMoleculeComponent(mol, name='ethane') - - -@pytest.mark.parametrize('internal,rdkit_name,name,expected', [ - ('', 'foo', '', 'foo'), - ('', '', 'foo', 'foo'), - ('', 'bar', 'foo', 'foo'), - ('bar', '', 'foo', 'foo'), - ('baz', 'bar', 'foo', 'foo'), - ('foo', '', '', 'foo'), -]) + return SmallMoleculeComponent(mol, name="ethane") + + +@pytest.mark.parametrize( + "internal,rdkit_name,name,expected", + [ + ("", "foo", "", "foo"), + ("", "", "foo", "foo"), + ("", "bar", "foo", "foo"), + ("bar", "", "foo", "foo"), + ("baz", "bar", "foo", "foo"), + ("foo", "", "", "foo"), + ], +) def test_ensure_ofe_name(internal, rdkit_name, name, expected, recwarn): rdkit = Chem.AddHs(Chem.MolFromSmiles("CC")) if internal: - rdkit.SetProp('_Name', internal) + rdkit.SetProp("_Name", internal) if rdkit_name: - rdkit.SetProp('ofe-name', rdkit_name) + rdkit.SetProp("ofe-name", rdkit_name) out_name = _ensure_ofe_name(rdkit, name) - if {rdkit_name, internal} - {'foo', ''}: + if {rdkit_name, internal} - {"foo", ""}: # we should warn if rdkit properties are anything other than 'foo' # (expected) or the empty string (not set) assert len(recwarn) == 1 @@ -91,22 +96,22 @@ def test_warn_multiple_conformers(self): def test_rdkit_independence(self): # once we've constructed a Molecule, it is independent from the source - mol = Chem.MolFromSmiles('CC') + mol = Chem.MolFromSmiles("CC") AllChem.Compute2DCoords(mol) our_mol = SmallMoleculeComponent.from_rdkit(mol) - mol.SetProp('foo', 'bar') # this is the source molecule, not ours + mol.SetProp("foo", "bar") # this is the source molecule, not ours with pytest.raises(KeyError): - our_mol.to_rdkit().GetProp('foo') + our_mol.to_rdkit().GetProp("foo") def test_rdkit_copy_source_copy(self): # we should copy in any properties that were in the source molecule - mol = Chem.MolFromSmiles('CC') + mol = Chem.MolFromSmiles("CC") AllChem.Compute2DCoords(mol) - mol.SetProp('foo', 'bar') + mol.SetProp("foo", "bar") our_mol = SmallMoleculeComponent.from_rdkit(mol) - assert our_mol.to_rdkit().GetProp('foo') == 'bar' + assert our_mol.to_rdkit().GetProp("foo") == "bar" def test_equality_and_hash(self, ethane, alt_ethane): assert hash(ethane) == hash(alt_ethane) @@ -118,13 +123,13 @@ def test_equality_and_hash_name_differs(self, ethane, named_ethane): assert ethane != named_ethane def test_smiles(self, named_ethane): - assert named_ethane.smiles == 'CC' + assert named_ethane.smiles == "CC" def test_name(self, named_ethane): - assert named_ethane.name == 'ethane' + assert named_ethane.name == "ethane" def test_empty_name(self, alt_ethane): - assert alt_ethane.name == '' + assert alt_ethane.name == "" @pytest.mark.xfail def test_serialization_cycle(self, named_ethane): @@ -136,24 +141,23 @@ def test_serialization_cycle(self, named_ethane): assert serialized == reserialized def test_to_sdf_string(self, named_ethane, ethane_sdf): - with open(ethane_sdf, "r") as f: + with open(ethane_sdf) as f: expected = f.read() assert named_ethane.to_sdf() == expected @pytest.mark.xfail def test_from_sdf_string(self, named_ethane, ethane_sdf): - with open(ethane_sdf, "r") as f: + with open(ethane_sdf) as f: sdf_str = f.read() assert SmallMoleculeComponent.from_sdf_string(sdf_str) == named_ethane @pytest.mark.xfail - def test_from_sdf_file(self, named_ethane, ethane_sdf, - tmpdir): - with open(ethane_sdf, 'r') as f: + def test_from_sdf_file(self, named_ethane, ethane_sdf, tmpdir): + with open(ethane_sdf) as f: sdf_str = f.read() - with open(tmpdir / "temp.sdf", mode='w') as tmpf: + with open(tmpdir / "temp.sdf", mode="w") as tmpf: tmpf.write(sdf_str) assert SmallMoleculeComponent.from_sdf_file(tmpdir / "temp.sdf") == named_ethane @@ -163,7 +167,7 @@ def test_from_sdf_file_junk(self, toluene_mol2_path): SmallMoleculeComponent.from_sdf_file(toluene_mol2_path) def test_from_sdf_string_multiple_molecules(self, multi_molecule_sdf): - data = open(multi_molecule_sdf, 'r').read() + data = open(multi_molecule_sdf).read() with pytest.raises(RuntimeError, match="contains more than 1"): SmallMoleculeComponent.from_sdf_string(data) @@ -184,17 +188,20 @@ def test_serialization_cycle_smiles(self, named_ethane): assert named_ethane is not copy assert named_ethane.smiles == copy.smiles - @pytest.mark.parametrize('replace', ( - ['name'], - ['mol'], - ['name', 'mol'], - )) + @pytest.mark.parametrize( + "replace", + ( + ["name"], + ["mol"], + ["name", "mol"], + ), + ) def test_copy_with_replacements(self, named_ethane, replace): replacements = {} - if 'name' in replace: - replacements['name'] = "foo" + if "name" in replace: + replacements["name"] = "foo" - if 'mol' in replace: + if "mol" in replace: # it is a little weird to use copy_with_replacements to replace # the whole molecule (possibly keeping the same name), but it # should work if someone does! (could more easily imagine only @@ -203,16 +210,16 @@ def test_copy_with_replacements(self, named_ethane, replace): Chem.AllChem.Compute2DCoords(rdmol) mol = SmallMoleculeComponent.from_rdkit(rdmol) dct = mol._to_dict() - for item in ['atoms', 'bonds', 'conformer']: + for item in ["atoms", "bonds", "conformer"]: replacements[item] = dct[item] new = named_ethane.copy_with_replacements(**replacements) - if 'name' in replace: + if "name" in replace: assert new.name == "foo" else: assert new.name == "ethane" - if 'mol' in replace: + if "mol" in replace: assert new.smiles == "CO" else: assert new.smiles == "CC" @@ -228,15 +235,15 @@ def test_to_off(self, ethane): def test_to_off_name(self, named_ethane): off_ethane = named_ethane.to_openff() - assert off_ethane.name == 'ethane' + assert off_ethane.name == "ethane" @pytest.mark.skipif(not HAS_OFFTK, reason="no openff tookit available") class TestSmallMoleculeComponentPartialCharges: - @pytest.fixture(scope='function') + @pytest.fixture(scope="function") def charged_off_ethane(self, ethane): off_ethane = ethane.to_openff() - off_ethane.assign_partial_charges(partial_charge_method='am1bcc') + off_ethane.assign_partial_charges(partial_charge_method="am1bcc") return off_ethane def test_partial_charges_warning(self, charged_off_ethane): @@ -258,7 +265,7 @@ def test_partial_charges_not_formal_error(self, charged_off_ethane): def test_partial_charges_too_few_atoms(self): mol = Chem.AddHs(Chem.MolFromSmiles("CC")) Chem.AllChem.Compute2DCoords(mol) - mol.SetProp('atom.dprop.PartialCharge', '1') + mol.SetProp("atom.dprop.PartialCharge", "1") with pytest.raises(ValueError, match="Incorrect number of"): SmallMoleculeComponent.from_rdkit(mol) @@ -271,7 +278,7 @@ def test_partial_charges_applied_to_atoms(self): mol = Chem.AddHs(Chem.MolFromSmiles("C")) Chem.AllChem.Compute2DCoords(mol) # add some fake charges at the molecule level - mol.SetProp('atom.dprop.PartialCharge', '-1 0.25 0.25 0.25 0.25') + mol.SetProp("atom.dprop.PartialCharge", "-1 0.25 0.25 0.25 0.25") matchmsg = "Partial charges have been provided" with pytest.warns(UserWarning, match=matchmsg): ofe = SmallMoleculeComponent.from_rdkit(mol) @@ -292,22 +299,24 @@ def test_inconsistent_charges(self, charged_off_ethane): mol = Chem.AddHs(Chem.MolFromSmiles("C")) Chem.AllChem.Compute2DCoords(mol) # add some fake charges at the molecule level - mol.SetProp('atom.dprop.PartialCharge', '-1 0.25 0.25 0.25 0.25') + mol.SetProp("atom.dprop.PartialCharge", "-1 0.25 0.25 0.25 0.25") # set different charges to the atoms for atom in mol.GetAtoms(): atom.SetDoubleProp("PartialCharge", 0) # make sure the correct error is raised - msg = ("non-equivalent partial charges between " - "atom and molecule properties") + msg = "non-equivalent partial charges between " "atom and molecule properties" with pytest.raises(ValueError, match=msg): SmallMoleculeComponent.from_rdkit(mol) - -@pytest.mark.parametrize('mol, charge', [ - ('CC', 0), ('CC[O-]', -1), -]) +@pytest.mark.parametrize( + "mol, charge", + [ + ("CC", 0), + ("CC[O-]", -1), + ], +) def test_total_charge_neutral(mol, charge): mol = Chem.MolFromSmiles(mol) AllChem.Compute2DCoords(mol) @@ -330,6 +339,36 @@ def test_to_dict(self, phenol): assert isinstance(d, dict) + def test_to_dict_hybridization(self, phenol): + """ + Make sure dict round trip saves the hybridization + + """ + phenol_dict = phenol.to_dict() + TOKENIZABLE_REGISTRY.clear() + new_phenol = SmallMoleculeComponent.from_dict(phenol_dict) + for atom in new_phenol.to_rdkit().GetAtoms(): + if atom.GetAtomicNum() == 6: + assert atom.GetHybridization() == Chem.rdchem.HybridizationType.SP2 + + def test_from_dict_missing_hybridization(self, phenol): + """ + For backwards compatibility make sure we can create an SMC with missing hybridization info. + """ + phenol_dict = phenol.to_dict() + new_atoms = [] + for atom in phenol_dict["atoms"]: + # remove the hybridization atomic info which should be at index 7 + new_atoms.append(tuple([atom_info for i, atom_info in enumerate(atom) if i != 7])) + phenol_dict["atoms"] = new_atoms + with pytest.warns(match="The atom hybridization data was not found and has been set to unspecified."): + new_phenol = SmallMoleculeComponent.from_dict(phenol_dict) + # they should be different objects due to the missing hybridization info + assert new_phenol != phenol + # make sure the rdkit objects are different + for atom_hybrid, atom_no_hybrid in zip(phenol.to_rdkit().GetAtoms(), new_phenol.to_rdkit().GetAtoms()): + assert atom_hybrid.GetHybridization() != atom_no_hybrid.GetHybridization() + @pytest.mark.skipif(not HAS_OFFTK, reason="no openff toolkit available") def test_deserialize_roundtrip(self, toluene, phenol): roundtrip = SmallMoleculeComponent.from_dict(phenol.to_dict()) @@ -347,11 +386,11 @@ def test_deserialize_roundtrip(self, toluene, phenol): @pytest.mark.xfail def test_bounce_off_file(self, toluene, tmpdir): - fname = str(tmpdir / 'mol.json') + fname = str(tmpdir / "mol.json") - with open(fname, 'w') as f: + with open(fname, "w") as f: f.write(toluene.to_json()) - with open(fname, 'r') as f: + with open(fname) as f: d = json.load(f) assert isinstance(d, dict) @@ -373,29 +412,29 @@ def test_to_openff_after_serialisation(self, toluene): assert off1 == off2 -@pytest.mark.parametrize('target', ['atom', 'bond', 'conformer', 'mol']) -@pytest.mark.parametrize('dtype', ['int', 'bool', 'str', 'float']) +@pytest.mark.parametrize("target", ["atom", "bond", "conformer", "mol"]) +@pytest.mark.parametrize("dtype", ["int", "bool", "str", "float"]) def test_prop_preservation(ethane, target, dtype): # issue 145 make sure props are propagated mol = Chem.MolFromSmiles("CC") Chem.AllChem.Compute2DCoords(mol) - if target == 'atom': + if target == "atom": obj = mol.GetAtomWithIdx(0) - elif target == 'bond': + elif target == "bond": obj = mol.GetBondWithIdx(0) - elif target == 'conformer': + elif target == "conformer": obj = mol.GetConformer() else: obj = mol - if dtype == 'int': - obj.SetIntProp('foo', 1234) - elif dtype == 'bool': - obj.SetBoolProp('foo', False) - elif dtype == 'str': - obj.SetProp('foo', 'bar') - elif dtype == 'float': - obj.SetDoubleProp('foo', 1.234) + if dtype == "int": + obj.SetIntProp("foo", 1234) + elif dtype == "bool": + obj.SetBoolProp("foo", False) + elif dtype == "str": + obj.SetProp("foo", "bar") + elif dtype == "float": + obj.SetDoubleProp("foo", 1.234) else: pytest.fail() @@ -403,29 +442,29 @@ def test_prop_preservation(ethane, target, dtype): d = SmallMoleculeComponent(rdkit=mol).to_dict() e2 = SmallMoleculeComponent.from_dict(d).to_rdkit() - if target == 'atom': + if target == "atom": obj = e2.GetAtomWithIdx(0) - elif target == 'bond': + elif target == "bond": obj = e2.GetBondWithIdx(0) - elif target == 'conformer': + elif target == "conformer": obj = e2.GetConformer() else: obj = e2 - if dtype == 'int': - assert obj.GetIntProp('foo') == 1234 - elif dtype == 'bool': - assert obj.GetBoolProp('foo') is False - elif dtype == 'str': - assert obj.GetProp('foo') == 'bar' + if dtype == "int": + assert obj.GetIntProp("foo") == 1234 + elif dtype == "bool": + assert obj.GetBoolProp("foo") is False + elif dtype == "str": + assert obj.GetProp("foo") == "bar" else: - assert obj.GetDoubleProp('foo') == pytest.approx(1.234) + assert obj.GetDoubleProp("foo") == pytest.approx(1.234) def test_missing_H_warning(): - m = Chem.MolFromSmiles('CC') + m = Chem.MolFromSmiles("CC") Chem.AllChem.Compute2DCoords(m) - with pytest.warns(UserWarning, match='removeHs=False'): + with pytest.warns(UserWarning, match="removeHs=False"): _ = SmallMoleculeComponent(rdkit=m) diff --git a/gufe/tests/test_solvents.py b/gufe/tests/test_solvents.py index 88b2a87b..e8530777 100644 --- a/gufe/tests/test_solvents.py +++ b/gufe/tests/test_solvents.py @@ -1,7 +1,7 @@ import pytest +from openff.units import unit from gufe import SolventComponent -from openff.units import unit from .test_tokenization import GufeTokenizableTestsMixin @@ -9,70 +9,77 @@ def test_defaults(): s = SolventComponent() - assert s.smiles == 'O' + assert s.smiles == "O" assert s.positive_ion == "Na+" assert s.negative_ion == "Cl-" assert s.ion_concentration == 0.15 * unit.molar -@pytest.mark.parametrize('pos, neg', [ - # test: charge dropping, case sensitivity - ('Na', 'Cl'), ('Na+', 'Cl-'), ('na', 'cl'), -]) +@pytest.mark.parametrize( + "pos, neg", + [ + # test: charge dropping, case sensitivity + ("Na", "Cl"), + ("Na+", "Cl-"), + ("na", "cl"), + ], +) def test_hash(pos, neg): - s1 = SolventComponent(positive_ion='Na', negative_ion='Cl') + s1 = SolventComponent(positive_ion="Na", negative_ion="Cl") s2 = SolventComponent(positive_ion=pos, negative_ion=neg) assert s1 == s2 assert hash(s1) == hash(s2) - assert s2.positive_ion == 'Na+' - assert s2.negative_ion == 'Cl-' + assert s2.positive_ion == "Na+" + assert s2.negative_ion == "Cl-" def test_neq(): - s1 = SolventComponent(positive_ion='Na', negative_ion='Cl') - s2 = SolventComponent(positive_ion='K', negative_ion='Cl') + s1 = SolventComponent(positive_ion="Na", negative_ion="Cl") + s2 = SolventComponent(positive_ion="K", negative_ion="Cl") assert s1 != s2 -@pytest.mark.parametrize('conc', [0.0 * unit.molar, 1.75 * unit.molar]) +@pytest.mark.parametrize("conc", [0.0 * unit.molar, 1.75 * unit.molar]) def test_from_dict(conc): - s1 = SolventComponent(positive_ion='Na', negative_ion='Cl', - ion_concentration=conc, - neutralize=False) + s1 = SolventComponent(positive_ion="Na", negative_ion="Cl", ion_concentration=conc, neutralize=False) assert SolventComponent.from_dict(s1.to_dict()) == s1 def test_conc(): - s = SolventComponent(positive_ion='Na', negative_ion='Cl', - ion_concentration=1.75 * unit.molar) + s = SolventComponent(positive_ion="Na", negative_ion="Cl", ion_concentration=1.75 * unit.molar) - assert s.ion_concentration == unit.Quantity('1.75 M') + assert s.ion_concentration == unit.Quantity("1.75 M") -@pytest.mark.parametrize('conc,', - [1.22, # no units, 1.22 what? - 1.5 * unit.kg, # probably a tad much salt - -0.1 * unit.molar]) # negative conc +@pytest.mark.parametrize( + "conc,", + [ + 1.22, # no units, 1.22 what? + 1.5 * unit.kg, # probably a tad much salt + -0.1 * unit.molar, + ], +) # negative conc def test_bad_conc(conc): with pytest.raises(ValueError): - _ = SolventComponent(positive_ion='Na', negative_ion='Cl', - ion_concentration=conc) + _ = SolventComponent(positive_ion="Na", negative_ion="Cl", ion_concentration=conc) def test_solvent_charge(): - s = SolventComponent(positive_ion='Na', negative_ion='Cl', - ion_concentration=1.75 * unit.molar) + s = SolventComponent(positive_ion="Na", negative_ion="Cl", ion_concentration=1.75 * unit.molar) assert s.total_charge is None -@pytest.mark.parametrize('pos, neg,', [ - ('Na', 'C'), - ('F', 'I'), -]) +@pytest.mark.parametrize( + "pos, neg,", + [ + ("Na", "C"), + ("F", "I"), + ], +) def test_bad_inputs(pos, neg): with pytest.raises(ValueError): _ = SolventComponent(positive_ion=pos, negative_ion=neg) @@ -85,4 +92,4 @@ class TestSolventComponent(GufeTokenizableTestsMixin): @pytest.fixture def instance(self): - return SolventComponent(positive_ion='Na', negative_ion='Cl') + return SolventComponent(positive_ion="Na", negative_ion="Cl") diff --git a/gufe/tests/test_tokenization.py b/gufe/tests/test_tokenization.py index f7b0744d..fc2b49ce 100644 --- a/gufe/tests/test_tokenization.py +++ b/gufe/tests/test_tokenization.py @@ -1,17 +1,26 @@ -import pytest import abc import datetime -import logging import io -from unittest import mock import json +import logging from typing import Optional +from unittest import mock + +import pytest from gufe.tokenization import ( - GufeTokenizable, GufeKey, tokenize, TOKENIZABLE_REGISTRY, - import_qualname, get_class, TOKENIZABLE_CLASS_REGISTRY, JSON_HANDLER, - get_all_gufe_objs, gufe_to_digraph, gufe_objects_from_shallow_dict, + JSON_HANDLER, + TOKENIZABLE_CLASS_REGISTRY, + TOKENIZABLE_REGISTRY, + GufeKey, + GufeTokenizable, KeyedChain, + get_all_gufe_objs, + get_class, + gufe_objects_from_shallow_dict, + gufe_to_digraph, + import_qualname, + tokenize, ) @@ -65,7 +74,7 @@ def __init__(self, obj, lst, dct): self.dct = dct def _to_dict(self): - return {'obj': self.obj, 'lst': self.lst, 'dct': self.dct} + return {"obj": self.obj, "lst": self.lst, "dct": self.dct} @classmethod def _from_dict(cls, dct): @@ -87,9 +96,7 @@ class GufeTokenizableTestsMixin(abc.ABC): @pytest.fixture def instance(self): - """Define instance to test with here. - - """ + """Define instance to test with here.""" ... def test_to_dict_roundtrip(self, instance): @@ -102,7 +109,7 @@ def test_to_dict_roundtrip(self, instance): # not generally true that the dict forms are equal, e.g. if they # include `np.nan`s - #assert ser == reser + # assert ser == reser @pytest.mark.skip def test_to_dict_roundtrip_clear_registry(self, instance): @@ -125,7 +132,7 @@ def test_to_keyed_dict_roundtrip(self, instance): # not generally true that the dict forms are equal, e.g. if they # include `np.nan`s - #assert ser == reser + # assert ser == reser def test_to_shallow_dict_roundtrip(self, instance): ser = instance.to_shallow_dict() @@ -137,7 +144,7 @@ def test_to_shallow_dict_roundtrip(self, instance): # not generally true that the dict forms are equal, e.g. if they # include `np.nan`s - #assert ser == reser + # assert ser == reser def test_key_stable(self, instance): """Check that generating the instance from a dict representation yields @@ -164,9 +171,7 @@ class TestGufeTokenizable(GufeTokenizableTestsMixin): @pytest.fixture def instance(self): - """Define instance to test with here. - - """ + """Define instance to test with here.""" return self.cont def setup_method(self): @@ -176,49 +181,55 @@ def setup_method(self): self.cont = Container(bar, [leaf, 0], {"leaf": leaf, "a": "b"}) def leaf_dict(a): - return {'__module__': __name__, '__qualname__': "Leaf", "a": a, - "b": 2, ':version:': 1} + return { + "__module__": __name__, + "__qualname__": "Leaf", + "a": a, + "b": 2, + ":version:": 1, + } self.expected_deep = { - '__qualname__': "Container", - '__module__': __name__, - 'obj': leaf_dict(leaf_dict("foo")), - 'lst': [leaf_dict("foo"), 0], - 'dct': {"leaf": leaf_dict("foo"), "a": "b"}, - ':version:': 1, + "__qualname__": "Container", + "__module__": __name__, + "obj": leaf_dict(leaf_dict("foo")), + "lst": [leaf_dict("foo"), 0], + "dct": {"leaf": leaf_dict("foo"), "a": "b"}, + ":version:": 1, } self.expected_shallow = { - '__qualname__': "Container", - '__module__': __name__, - 'obj': bar, - 'lst': [leaf, 0], - 'dct': {'leaf': leaf, 'a': 'b'}, - ':version:': 1, + "__qualname__": "Container", + "__module__": __name__, + "obj": bar, + "lst": [leaf, 0], + "dct": {"leaf": leaf, "a": "b"}, + ":version:": 1, } self.expected_keyed = { - '__qualname__': "Container", - '__module__': __name__, - 'obj': {":gufe-key:": bar.key}, - 'lst': [{":gufe-key:": leaf.key}, 0], - 'dct': {'leaf': {":gufe-key:": leaf.key}, 'a': 'b'}, - ':version:': 1, + "__qualname__": "Container", + "__module__": __name__, + "obj": {":gufe-key:": bar.key}, + "lst": [{":gufe-key:": leaf.key}, 0], + "dct": {"leaf": {":gufe-key:": leaf.key}, "a": "b"}, + ":version:": 1, } self.expected_keyed_chain = [ - (str(leaf.key), - leaf_dict("foo")), - (str(bar.key), - leaf_dict({':gufe-key:': str(leaf.key)})), - (str(self.cont.key), - {':version:': 1, - '__module__': __name__, - '__qualname__': 'Container', - 'dct': {'a': 'b', - 'leaf': {':gufe-key:': str(leaf.key)}}, - 'lst': [{':gufe-key:': str(leaf.key)}, 0], - 'obj': {':gufe-key:': str(bar.key)}}) + (str(leaf.key), leaf_dict("foo")), + (str(bar.key), leaf_dict({":gufe-key:": str(leaf.key)})), + ( + str(self.cont.key), + { + ":version:": 1, + "__module__": __name__, + "__qualname__": "Container", + "dct": {"a": "b", "leaf": {":gufe-key:": str(leaf.key)}}, + "lst": [{":gufe-key:": str(leaf.key)}, 0], + "obj": {":gufe-key:": str(bar.key)}, + }, + ), ] def test_set_key(self): @@ -283,7 +294,11 @@ def test_to_json_file(self, tmpdir): def test_from_json_file(self, tmpdir): file_path = tmpdir / "container.json" - json.dump(self.expected_keyed_chain, file_path.open(mode="w"), cls=JSON_HANDLER.encoder) + json.dump( + self.expected_keyed_chain, + file_path.open(mode="w"), + cls=JSON_HANDLER.encoder, + ) recreated = self.cls.from_json(file=file_path) assert recreated == self.cont @@ -299,7 +314,7 @@ def test_from_shallow_dict(self): # here we keep the same objects in memory assert recreated.obj.a is recreated.lst[0] - assert recreated.obj.a is recreated.dct['leaf'] + assert recreated.obj.a is recreated.dct["leaf"] def test_notequal_different_type(self): l1 = Leaf(4) @@ -326,13 +341,11 @@ def test_copy_with_replacements_invalid(self): with pytest.raises(TypeError, match="Invalid"): _ = l1.copy_with_replacements(foo=10) - @pytest.mark.parametrize('level', ["DEBUG", "INFO", "CRITICAL"]) + @pytest.mark.parametrize("level", ["DEBUG", "INFO", "CRITICAL"]) def test_logging(self, level): stream = io.StringIO() handler = logging.StreamHandler(stream) - fmt = logging.Formatter( - "%(name)s - %(gufekey)s - %(levelname)s - %(message)s" - ) + fmt = logging.Formatter("%(name)s - %(gufekey)s - %(levelname)s - %(message)s") name = "gufekey.gufe.tests.test_tokenization.Leaf" logger = logging.getLogger(name) logger.setLevel(getattr(logging, level)) @@ -342,7 +355,7 @@ def test_logging(self, level): leaf = Leaf(10) results = stream.getvalue() - key = leaf.key.split('-')[-1] + key = leaf.key.split("-")[-1] initial_log = f"{name} - UNKNOWN - INFO - no key defined!\n" info_log = f"{name} - {key} - INFO - a=10\n" @@ -370,11 +383,14 @@ class Inner: pass -@pytest.mark.parametrize('modname, qualname, expected', [ - (__name__, "Outer", Outer), - (__name__, "Outer.Inner", Outer.Inner), - ("gufe.tokenization", 'import_qualname', import_qualname), -]) +@pytest.mark.parametrize( + "modname, qualname, expected", + [ + (__name__, "Outer", Outer), + (__name__, "Outer.Inner", Outer.Inner), + ("gufe.tokenization", "import_qualname", import_qualname), + ], +) def test_import_qualname(modname, qualname, expected): assert import_qualname(modname, qualname) is expected @@ -382,9 +398,9 @@ def test_import_qualname(modname, qualname, expected): def test_import_qualname_not_yet_imported(): # this is specifically to test that something we don't have imported in # this module will import correctly - msg_cls = import_qualname(modname="email.message", - qualname="EmailMessage") + msg_cls = import_qualname(modname="email.message", qualname="EmailMessage") from email.message import EmailMessage + assert msg_cls is EmailMessage @@ -393,26 +409,33 @@ def test_import_qualname_remappings(): assert import_qualname("foo", "Bar.Baz", remappings) is Outer.Inner -@pytest.mark.parametrize('modname, qualname', [ - (None, "Outer.Inner"), - (__name__, None), -]) +@pytest.mark.parametrize( + "modname, qualname", + [ + (None, "Outer.Inner"), + (__name__, None), + ], +) def test_import_qualname_error_none(modname, qualname): with pytest.raises(ValueError, match="cannot be None"): import_qualname(modname, qualname) - -@pytest.mark.parametrize('cls_reg', [ - {}, - {(__name__, "Outer.Inner"): Outer.Inner}, -]) +@pytest.mark.parametrize( + "cls_reg", + [ + {}, + {(__name__, "Outer.Inner"): Outer.Inner}, + ], +) def test_get_class(cls_reg): with mock.patch.dict("gufe.tokenization.TOKENIZABLE_CLASS_REGISTRY", cls_reg): assert get_class(__name__, "Outer.Inner") is Outer.Inner + def test_path_to_json(): import pathlib + p = pathlib.Path("foo/bar") ser = json.dumps(p, cls=JSON_HANDLER.encoder) deser = json.loads(ser, cls=JSON_HANDLER.decoder) @@ -423,27 +446,25 @@ def test_path_to_json(): class TestGufeKey: def test_to_dict(self): - k = GufeKey('foo-bar') + k = GufeKey("foo-bar") - assert k.to_dict() == {':gufe-key:': 'foo-bar'} + assert k.to_dict() == {":gufe-key:": "foo-bar"} def test_prefix(self): - k = GufeKey('foo-bar') + k = GufeKey("foo-bar") - assert k.prefix == 'foo' + assert k.prefix == "foo" def test_token(self): - k = GufeKey('foo-bar') + k = GufeKey("foo-bar") - assert k.token == 'bar' + assert k.token == "bar" def test_gufe_to_digraph(solvated_complex): graph = gufe_to_digraph(solvated_complex) - connected_objects = gufe_objects_from_shallow_dict( - solvated_complex.to_shallow_dict() - ) + connected_objects = gufe_objects_from_shallow_dict(solvated_complex.to_shallow_dict()) assert len(graph.nodes) == 4 assert len(graph.edges) == 3 @@ -472,9 +493,7 @@ def test_from_gufe(self, benzene_variants_star_map): assert len(kc) == expected_len original_keys = [obj.key for obj in contained_objects] - original_keyed_dicts = [ - obj.to_keyed_dict() for obj in contained_objects - ] + original_keyed_dicts = [obj.to_keyed_dict() for obj in contained_objects] kc_gufe_keys = set(kc.gufe_keys()) kc_keyed_dicts = list(kc.keyed_dicts()) @@ -498,7 +517,7 @@ def test_get_item(self, benzene_variants_star_map): def test_datetime_to_json(): - d = datetime.datetime.fromisoformat('2023-05-05T09:06:43.699068') + d = datetime.datetime.fromisoformat("2023-05-05T09:06:43.699068") ser = json.dumps(d, cls=JSON_HANDLER.encoder) diff --git a/gufe/tests/test_transformation.py b/gufe/tests/test_transformation.py index 3ad29c59..445ff651 100644 --- a/gufe/tests/test_transformation.py +++ b/gufe/tests/test_transformation.py @@ -1,13 +1,14 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -import pytest import io import pathlib +import pytest + import gufe -from gufe.transformations import Transformation, NonTransformation from gufe.protocols.protocoldag import execute_DAG +from gufe.transformations import NonTransformation, Transformation from .test_protocol import DummyProtocol, DummyProtocolResult from .test_tokenization import GufeTokenizableTestsMixin @@ -25,8 +26,10 @@ def absolute_transformation(solvated_ligand, solvated_complex): @pytest.fixture def complex_equilibrium(solvated_complex): - return NonTransformation(solvated_complex, - protocol=DummyProtocol(settings=DummyProtocol.default_settings())) + return NonTransformation( + solvated_complex, + protocol=DummyProtocol(settings=DummyProtocol.default_settings()), + ) class TestTransformation(GufeTokenizableTestsMixin): @@ -51,12 +54,12 @@ def test_protocol(self, absolute_transformation, tmpdir): protocoldag = tnf.create() with tmpdir.as_cwd(): - shared = pathlib.Path('shared') + shared = pathlib.Path("shared") shared.mkdir(parents=True) - scratch = pathlib.Path('scratch') + scratch = pathlib.Path("scratch") scratch.mkdir(parents=True) - + protocoldagresult = execute_DAG(protocoldag, shared_basedir=shared, scratch_basedir=scratch) protocolresult = tnf.gather([protocoldagresult]) @@ -64,8 +67,8 @@ def test_protocol(self, absolute_transformation, tmpdir): assert isinstance(protocolresult, DummyProtocolResult) assert len(protocolresult.data) == 2 - assert 'logs' in protocolresult.data - assert 'key_results' in protocolresult.data + assert "logs" in protocolresult.data + assert "key_results" in protocolresult.data def test_protocol_extend(self, absolute_transformation, tmpdir): tnf = absolute_transformation @@ -73,12 +76,12 @@ def test_protocol_extend(self, absolute_transformation, tmpdir): assert isinstance(tnf.protocol, DummyProtocol) with tmpdir.as_cwd(): - shared = pathlib.Path('shared') + shared = pathlib.Path("shared") shared.mkdir(parents=True) - scratch = pathlib.Path('scratch') + scratch = pathlib.Path("scratch") scratch.mkdir(parents=True) - + protocoldag = tnf.create() protocoldagresult = execute_DAG(protocoldag, shared_basedir=shared, scratch_basedir=scratch) @@ -94,8 +97,9 @@ def test_protocol_extend(self, absolute_transformation, tmpdir): def test_equality(self, absolute_transformation, solvated_ligand, solvated_complex): opposite = Transformation( - solvated_complex, solvated_ligand, - protocol=DummyProtocol(settings=DummyProtocol.default_settings()) + solvated_complex, + solvated_ligand, + protocol=DummyProtocol(settings=DummyProtocol.default_settings()), ) assert absolute_transformation != opposite @@ -124,16 +128,16 @@ def test_dump_load_roundtrip(self, absolute_transformation): assert absolute_transformation == recreated def test_deprecation_warning_on_dict_mapping(self, solvated_ligand, solvated_complex): - lig = solvated_complex.components['ligand'] + lig = solvated_complex.components["ligand"] # this mapping makes no sense, but it'll trigger the dep warning we want mapping = gufe.LigandAtomMapping(lig, lig, componentA_to_componentB={}) - with pytest.warns(DeprecationWarning, - match="mapping input as a dict is deprecated"): + with pytest.warns(DeprecationWarning, match="mapping input as a dict is deprecated"): Transformation( - solvated_complex, solvated_ligand, + solvated_complex, + solvated_ligand, protocol=DummyProtocol(settings=DummyProtocol.default_settings()), - mapping={'ligand': mapping}, + mapping={"ligand": mapping}, ) @@ -160,10 +164,10 @@ def test_protocol(self, complex_equilibrium, tmpdir): protocoldag = ntnf.create() with tmpdir.as_cwd(): - shared = pathlib.Path('shared') + shared = pathlib.Path("shared") shared.mkdir(parents=True) - scratch = pathlib.Path('scratch') + scratch = pathlib.Path("scratch") scratch.mkdir(parents=True) protocoldagresult = execute_DAG(protocoldag, shared_basedir=shared, scratch_basedir=scratch) @@ -173,8 +177,8 @@ def test_protocol(self, complex_equilibrium, tmpdir): assert isinstance(protocolresult, DummyProtocolResult) assert len(protocolresult.data) == 2 - assert 'logs' in protocolresult.data - assert 'key_results' in protocolresult.data + assert "logs" in protocolresult.data + assert "key_results" in protocolresult.data def test_protocol_extend(self, complex_equilibrium, tmpdir): ntnf = complex_equilibrium @@ -182,10 +186,10 @@ def test_protocol_extend(self, complex_equilibrium, tmpdir): assert isinstance(ntnf.protocol, DummyProtocol) with tmpdir.as_cwd(): - shared = pathlib.Path('shared') + shared = pathlib.Path("shared") shared.mkdir(parents=True) - scratch = pathlib.Path('scratch') + scratch = pathlib.Path("scratch") scratch.mkdir(parents=True) protocoldag = ntnf.create() @@ -203,17 +207,17 @@ def test_protocol_extend(self, complex_equilibrium, tmpdir): def test_equality(self, complex_equilibrium, solvated_ligand, solvated_complex): s = DummyProtocol.default_settings() s.n_repeats = 4031 - different_protocol_settings = NonTransformation( - solvated_complex, protocol=DummyProtocol(settings=s) - ) + different_protocol_settings = NonTransformation(solvated_complex, protocol=DummyProtocol(settings=s)) assert complex_equilibrium != different_protocol_settings identical = NonTransformation( - solvated_complex, protocol=DummyProtocol(settings=DummyProtocol.default_settings()) + solvated_complex, + protocol=DummyProtocol(settings=DummyProtocol.default_settings()), ) assert complex_equilibrium == identical different_system = NonTransformation( - solvated_ligand, protocol=DummyProtocol(settings=DummyProtocol.default_settings()) + solvated_ligand, + protocol=DummyProtocol(settings=DummyProtocol.default_settings()), ) assert complex_equilibrium != different_system diff --git a/gufe/tests/test_utils.py b/gufe/tests/test_utils.py index 8bd83749..1e2ec962 100644 --- a/gufe/tests/test_utils.py +++ b/gufe/tests/test_utils.py @@ -1,30 +1,28 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -import pytest - import io import pathlib +import pytest + from gufe.utils import ensure_filelike -@pytest.mark.parametrize('input_type', [ - "str", "path", "TextIO", "BytesIO", "StringIO" -]) +@pytest.mark.parametrize("input_type", ["str", "path", "TextIO", "BytesIO", "StringIO"]) def test_ensure_filelike(input_type, tmp_path): path = tmp_path / "foo.txt" # we choose to use bytes for pathlib.Path just to mix things up; # string filename or path can be either bytes or string, so we give one # to each - use_bytes = input_type in {'path', 'BytesIO'} - filelike = input_type not in {'str', 'path'} + use_bytes = input_type in {"path", "BytesIO"} + filelike = input_type not in {"str", "path"} dumper = { - 'str': str(path), - 'path': path, - 'TextIO': open(path, mode='w'), - 'BytesIO': open(path, mode='wb'), - 'StringIO': io.StringIO(), + "str": str(path), + "path": path, + "TextIO": open(path, mode="w"), + "BytesIO": open(path, mode="wb"), + "StringIO": io.StringIO(), }[input_type] if filelike: @@ -40,15 +38,15 @@ def test_ensure_filelike(input_type, tmp_path): write_f.write(written) write_f.flush() - if input_type == 'StringIO': + if input_type == "StringIO": dumper.seek(0) loader = { - 'str': str(path), - 'path': path, - 'TextIO': open(path, mode='r'), - 'BytesIO': open(path, mode='rb'), - 'StringIO': dumper, + "str": str(path), + "path": path, + "TextIO": open(path), + "BytesIO": open(path, mode="rb"), + "StringIO": dumper, }[input_type] with ensure_filelike(loader, mode=read_mode) as read_f: @@ -64,13 +62,14 @@ def test_ensure_filelike(input_type, tmp_path): write_f.close() read_f.close() + @pytest.mark.parametrize("input_type", ["TextIO", "BytesIO", "StringIO"]) def test_ensure_filelike_force_close(input_type, tmp_path): path = tmp_path / "foo.txt" dumper = { - 'TextIO': open(path, mode='w'), - 'BytesIO': open(path, mode='wb'), - 'StringIO': io.StringIO(), + "TextIO": open(path, mode="w"), + "BytesIO": open(path, mode="wb"), + "StringIO": io.StringIO(), }[input_type] written = b"foo" if input_type == "BytesIO" else "foo" @@ -79,22 +78,23 @@ def test_ensure_filelike_force_close(input_type, tmp_path): assert f.closed + @pytest.mark.parametrize("input_type", ["TextIO", "BytesIO", "StringIO"]) def test_ensure_filelike_mode_warning(input_type, tmp_path): path = tmp_path / "foo.txt" dumper = { - 'TextIO': open(path, mode='w'), - 'BytesIO': open(path, mode='wb'), - 'StringIO': io.StringIO(), + "TextIO": open(path, mode="w"), + "BytesIO": open(path, mode="wb"), + "StringIO": io.StringIO(), }[input_type] - with pytest.warns(UserWarning, - match="User-specified mode will be ignored"): + with pytest.warns(UserWarning, match="User-specified mode will be ignored"): _ = ensure_filelike(dumper, mode="w") dumper.close() + def test_ensure_filelike_default_mode(): path = "foo.txt" loader = ensure_filelike(path) - assert loader.mode == 'r' + assert loader.mode == "r" diff --git a/gufe/tokenization.py b/gufe/tokenization.py index e7c959db..9ace1bf0 100644 --- a/gufe/tokenization.py +++ b/gufe/tokenization.py @@ -9,13 +9,15 @@ import inspect import json import logging -import networkx as nx import re import warnings import weakref +from collections.abc import Generator from itertools import chain from os import PathLike -from typing import Any, Union, List, Tuple, Dict, Generator, TextIO, Optional +from typing import Any, Dict, List, Optional, TextIO, Tuple, Union + +import networkx as nx from typing_extensions import Self from gufe.custom_codecs import ( @@ -115,11 +117,12 @@ class _GufeLoggerAdapter(logging.LoggerAdapter): extra: :class:`.GufeTokenizable` the instance this adapter is associated with """ + def process(self, msg, kwargs): - extra = kwargs.get('extra', {}) - if (extra_dict := getattr(self, '_extra_dict', None)) is None: + extra = kwargs.get("extra", {}) + if (extra_dict := getattr(self, "_extra_dict", None)) is None: try: - gufekey = self.extra.key.split('-')[-1] + gufekey = self.extra.key.split("-")[-1] except Exception: # no matter what happened, we have a bad key gufekey = "UNKNOWN" @@ -127,15 +130,13 @@ def process(self, msg, kwargs): else: save_extra_dict = True - extra_dict = { - 'gufekey': gufekey - } + extra_dict = {"gufekey": gufekey} if save_extra_dict: self._extra_dict = extra_dict extra.update(extra_dict) - kwargs['extra'] = extra + kwargs["extra"] = extra return msg, kwargs @@ -193,8 +194,9 @@ def old_key_removed(dct, old_key, should_warn): if should_warn: # TODO: this should be put elsewhere so that the warning can be more # meaningful (somewhere that knows what class we're recreating) - warnings.warn(f"Outdated serialization: '{old_key}', with value " - f"'{dct[old_key]}' is no longer used in this object") + warnings.warn( + f"Outdated serialization: '{old_key}', with value " f"'{dct[old_key]}' is no longer used in this object" + ) del dct[old_key] return dct @@ -231,6 +233,7 @@ def _label_to_parts(label): See :func:`.nested_key_moved` for a description of the label. """ + def _intify_if_possible(part): try: part = int(part) @@ -238,10 +241,8 @@ def _intify_if_possible(part): pass return part - parts = [ - _intify_if_possible(p) for p in re.split('\.|\[|\]', label) - if p != "" - ] + + parts = [_intify_if_possible(p) for p in re.split(r"\.|\[|\]", label) if p != ""] return parts @@ -326,6 +327,7 @@ class GufeTokenizable(abc.ABC, metaclass=_ABCGufeClassMeta): This extra work in serializing is important for hashes that are stable *across different Python sessions*. """ + @classmethod def _schema_version(cls) -> int: return 1 @@ -346,9 +348,7 @@ def __hash__(self): return hash(self.key) def _gufe_tokenize(self): - """Return a list of normalized inputs for `gufe.base.tokenize`. - - """ + """Return a list of normalized inputs for `gufe.base.tokenize`.""" return tokenize(self) # return normalize(self.to_keyed_dict(include_defaults=False)) @@ -405,7 +405,7 @@ def serialization_migration(cls, old_dict, version): @property def logger(self): """Return logger adapter for this instance""" - if (adapter := getattr(self, '_logger', None)) is None: + if (adapter := getattr(self, "_logger", None)) is None: cls = self.__class__ logname = "gufekey." + cls.__module__ + "." + cls.__qualname__ logger = logging.getLogger(logname) @@ -415,7 +415,7 @@ def logger(self): @property def key(self): - if not hasattr(self, '_key') or self._key is None: + if not hasattr(self, "_key") or self._key is None: prefix = self.__class__.__qualname__ token = self._gufe_tokenize() self._key = GufeKey(f"{prefix}-{token}") @@ -433,7 +433,7 @@ def _set_key(self, key: str): key : str contents of the GufeKey for this object """ - if old_key := getattr(self, '_key', None): + if old_key := getattr(self, "_key", None): TOKENIZABLE_REGISTRY.pop(old_key) self._key = GufeKey(key) @@ -448,7 +448,7 @@ def defaults(cls): """ defaults = cls._defaults() - defaults[':version:'] = cls._schema_version() + defaults[":version:"] = cls._schema_version() return defaults @classmethod @@ -461,7 +461,8 @@ def _defaults(cls): sig = inspect.signature(cls.__init__) defaults = { - param.name: param.default for param in sig.parameters.values() + param.name: param.default + for param in sig.parameters.values() if param.default is not inspect.Parameter.empty } @@ -617,13 +618,12 @@ def copy_with_replacements(self, **replacements): """ dct = self._to_dict() if invalid := set(replacements) - set(dct): - raise TypeError(f"Invalid replacement keys: {invalid}. " - f"Allowed keys are: {set(dct)}") + raise TypeError(f"Invalid replacement keys: {invalid}. " f"Allowed keys are: {set(dct)}") dct.update(replacements) return self._from_dict(dct) - def to_keyed_chain(self) -> List[Tuple[str, Dict]]: + def to_keyed_chain(self) -> list[tuple[str, dict]]: """ Generate a keyed chain representation of the object. @@ -634,7 +634,7 @@ def to_keyed_chain(self) -> List[Tuple[str, Dict]]: return KeyedChain.gufe_to_keyed_chain_rep(self) @classmethod - def from_keyed_chain(cls, keyed_chain: List[Tuple[str, Dict]]): + def from_keyed_chain(cls, keyed_chain: list[tuple[str, dict]]): """ Generate an instance from keyed chain representation. @@ -674,6 +674,7 @@ def to_json(self, file: Optional[PathLike | TextIO] = None) -> None | str: return json.dumps(self.to_keyed_chain(), cls=JSON_HANDLER.encoder) from gufe.utils import ensure_filelike + with ensure_filelike(file, mode="w") as out: json.dump(self.to_keyed_chain(), out, cls=JSON_HANDLER.encoder) @@ -708,6 +709,7 @@ def from_json(cls, file: Optional[PathLike | TextIO] = None, content: Optional[s return cls.from_keyed_chain(keyed_chain=keyed_chain) from gufe.utils import ensure_filelike + with ensure_filelike(file, mode="r") as f: keyed_chain = json.load(f, cls=JSON_HANDLER.decoder) @@ -715,26 +717,24 @@ def from_json(cls, file: Optional[PathLike | TextIO] = None, content: Optional[s class GufeKey(str): - def __repr__(self): # pragma: no cover + def __repr__(self): # pragma: no cover return f"" def to_dict(self): - return {':gufe-key:': str(self)} + return {":gufe-key:": str(self)} @property def prefix(self) -> str: """Commonly indicates a classname""" - return self.split('-')[0] + return self.split("-")[0] @property def token(self) -> str: """Unique hash of this key, typically a md5 value""" - return self.split('-')[1] + return self.split("-")[1] -def gufe_objects_from_shallow_dict( - obj: Union[List, Dict, GufeTokenizable] -) -> List[GufeTokenizable]: +def gufe_objects_from_shallow_dict(obj: Union[list, dict, GufeTokenizable]) -> list[GufeTokenizable]: """Find GufeTokenizables within a shallow dict. This function recursively looks through the list/dict structures encoding @@ -759,16 +759,10 @@ def gufe_objects_from_shallow_dict( return [obj] elif isinstance(obj, list): - return list( - chain.from_iterable([gufe_objects_from_shallow_dict(item) for item in obj]) - ) + return list(chain.from_iterable([gufe_objects_from_shallow_dict(item) for item in obj])) elif isinstance(obj, dict): - return list( - chain.from_iterable( - [gufe_objects_from_shallow_dict(item) for item in obj.values()] - ) - ) + return list(chain.from_iterable([gufe_objects_from_shallow_dict(item) for item in obj.values()])) return [] @@ -811,7 +805,7 @@ def add_edges(o): return graph -class KeyedChain(object): +class KeyedChain: """Keyed chain representation encoder of a GufeTokenizable. The keyed chain representation of a GufeTokenizable provides a @@ -859,25 +853,25 @@ def from_gufe(cls, gufe_object: GufeTokenizable) -> Self: def to_gufe(self) -> GufeTokenizable: """Initialize a GufeTokenizable.""" - gts: Dict[str, GufeTokenizable] = {} + gts: dict[str, GufeTokenizable] = {} for gufe_key, keyed_dict in self: gt = key_decode_dependencies(keyed_dict, registry=gts) gts[gufe_key] = gt return gt @classmethod - def from_keyed_chain_rep(cls, keyed_chain: List[Tuple[str, Dict]]) -> Self: + def from_keyed_chain_rep(cls, keyed_chain: list[tuple[str, dict]]) -> Self: """Initialize a KeyedChain from a keyed chain representation.""" return cls(keyed_chain) - def to_keyed_chain_rep(self) -> List[Tuple[str, Dict]]: + def to_keyed_chain_rep(self) -> list[tuple[str, dict]]: """Return the keyed chain representation of this object.""" return list(self) @staticmethod def gufe_to_keyed_chain_rep( gufe_object: GufeTokenizable, - ) -> List[Tuple[str, Dict]]: + ) -> list[tuple[str, dict]]: """Create the keyed chain representation of a GufeTokenizable. This represents the GufeTokenizable as a list of two-element tuples @@ -897,8 +891,7 @@ def gufe_to_keyed_chain_rep( """ key_and_keyed_dicts = [ - (str(gt.key), gt.to_keyed_dict()) - for gt in nx.topological_sort(gufe_to_digraph(gufe_object)) + (str(gt.key), gt.to_keyed_dict()) for gt in nx.topological_sort(gufe_to_digraph(gufe_object)) ][::-1] return key_and_keyed_dicts @@ -907,7 +900,7 @@ def gufe_keys(self) -> Generator[str, None, None]: for key, _ in self: yield key - def keyed_dicts(self) -> Generator[Dict, None, None]: + def keyed_dicts(self) -> Generator[dict, None, None]: """Create a generator that iterates over the keyed dicts in the KeyedChain.""" for _, _dict in self: yield _dict @@ -936,8 +929,10 @@ def __getitem__(self, index): def module_qualname(obj): - return {'__qualname__': obj.__class__.__qualname__, - '__module__': obj.__class__.__module__} + return { + "__qualname__": obj.__class__.__qualname__, + "__module__": obj.__class__.__module__, + } def is_gufe_obj(obj: Any): @@ -945,8 +940,7 @@ def is_gufe_obj(obj: Any): def is_gufe_dict(dct: Any): - return (isinstance(dct, dict) and '__qualname__' in dct - and '__module__' in dct) + return isinstance(dct, dict) and "__qualname__" in dct and "__module__" in dct def is_gufe_key_dict(dct: Any): @@ -956,14 +950,15 @@ def is_gufe_key_dict(dct: Any): # conveniences to get a class from module/class name def import_qualname(modname: str, qualname: str, remappings=REMAPPED_CLASSES): if (qualname is None) or (modname is None): - raise ValueError("`__qualname__` or `__module__` cannot be None; " - f"unable to identify object {modname}.{qualname}") + raise ValueError( + "`__qualname__` or `__module__` cannot be None; " f"unable to identify object {modname}.{qualname}" + ) if (modname, qualname) in remappings: modname, qualname = remappings[(modname, qualname)] result = importlib.import_module(modname) - for name in qualname.split('.'): + for name in qualname.split("."): result = getattr(result, name) return result @@ -1000,18 +995,16 @@ def modify_dependencies(obj: Union[dict, list], modifier, is_mine, mode, top=Tru If `True`, skip modifying `obj` itself; needed for recursive use to avoid early stopping on `obj`. """ - if is_mine(obj) and not top and mode == 'encode': + if is_mine(obj) and not top and mode == "encode": obj = modifier(obj) if isinstance(obj, dict): - obj = {key: modify_dependencies(value, modifier, is_mine, mode=mode, top=False) - for key, value in obj.items()} + obj = {key: modify_dependencies(value, modifier, is_mine, mode=mode, top=False) for key, value in obj.items()} elif isinstance(obj, list): - obj = [modify_dependencies(item, modifier, is_mine, mode=mode, top=False) - for item in obj] + obj = [modify_dependencies(item, modifier, is_mine, mode=mode, top=False) for item in obj] - if is_mine(obj) and not top and mode == 'decode': + if is_mine(obj) and not top and mode == "decode": obj = modifier(obj) return obj @@ -1021,18 +1014,12 @@ def modify_dependencies(obj: Union[dict, list], modifier, is_mine, mode, top=Tru def to_dict(obj: GufeTokenizable) -> dict: dct = obj._to_dict() dct.update(module_qualname(obj)) - dct[':version:'] = obj._schema_version() + dct[":version:"] = obj._schema_version() return dct def dict_encode_dependencies(obj: GufeTokenizable) -> dict: - return modify_dependencies( - obj.to_shallow_dict(), - to_dict, - is_gufe_obj, - mode='encode', - top=True - ) + return modify_dependencies(obj.to_shallow_dict(), to_dict, is_gufe_obj, mode="encode", top=True) def key_encode_dependencies(obj: GufeTokenizable) -> dict: @@ -1040,8 +1027,8 @@ def key_encode_dependencies(obj: GufeTokenizable) -> dict: obj.to_shallow_dict(), lambda obj: obj.key.to_dict(), is_gufe_obj, - mode='encode', - top=True + mode="encode", + top=True, ) @@ -1064,9 +1051,9 @@ def from_dict(dct) -> GufeTokenizable: def _from_dict(dct: dict) -> GufeTokenizable: - module = dct.pop('__module__') - qualname = dct.pop('__qualname__') - version = dct.pop(':version:', 1) + module = dct.pop("__module__") + qualname = dct.pop("__qualname__") + version = dct.pop(":version:", 1) cls = get_class(module, qualname) dct = cls.serialization_migration(dct, version) @@ -1074,23 +1061,18 @@ def _from_dict(dct: dict) -> GufeTokenizable: def dict_decode_dependencies(dct: dict) -> GufeTokenizable: - return from_dict( - modify_dependencies(dct, from_dict, is_gufe_dict, mode='decode', top=True) - ) + return from_dict(modify_dependencies(dct, from_dict, is_gufe_dict, mode="decode", top=True)) -def key_decode_dependencies( - dct: dict, - registry=TOKENIZABLE_REGISTRY -) -> GufeTokenizable: +def key_decode_dependencies(dct: dict, registry=TOKENIZABLE_REGISTRY) -> GufeTokenizable: # this version requires that all dependent objects are already registered # responsibility of the storage system that uses this to do so dct = modify_dependencies( dct, lambda d: registry[GufeKey(d[":gufe-key:"])], is_gufe_key_dict, - mode='decode', - top=True + mode="decode", + top=True, ) return from_dict(dct) @@ -1111,12 +1093,12 @@ def get_all_gufe_objs(obj): all contained GufeTokenizables """ results = {obj} + def modifier(o): results.add(o) return o.to_shallow_dict() - _ = modify_dependencies(obj.to_shallow_dict(), modifier, is_gufe_obj, - mode='encode') + _ = modify_dependencies(obj.to_shallow_dict(), modifier, is_gufe_obj, mode="encode") return results @@ -1136,7 +1118,10 @@ def tokenize(obj: GufeTokenizable) -> str: """ # hasher = hashlib.md5(str(normalize(obj)).encode(), usedforsecurity=False) - dumped = json.dumps(obj.to_keyed_dict(include_defaults=False), - sort_keys=True, cls=JSON_HANDLER.encoder) + dumped = json.dumps( + obj.to_keyed_dict(include_defaults=False), + sort_keys=True, + cls=JSON_HANDLER.encoder, + ) hasher = hashlib.md5(dumped.encode(), usedforsecurity=False) return hasher.hexdigest() diff --git a/gufe/transformations/__init__.py b/gufe/transformations/__init__.py index fe016384..2bb0e7b5 100644 --- a/gufe/transformations/__init__.py +++ b/gufe/transformations/__init__.py @@ -1,2 +1,3 @@ """A chemical system and protocol combined form a Transformation""" -from .transformation import Transformation, NonTransformation + +from .transformation import NonTransformation, Transformation diff --git a/gufe/transformations/transformation.py b/gufe/transformations/transformation.py index e018fdc5..e8df7a3b 100644 --- a/gufe/transformations/transformation.py +++ b/gufe/transformations/transformation.py @@ -2,27 +2,34 @@ # For details, see https://github.com/OpenFreeEnergy/gufe import abc -from typing import Optional, Iterable, Union import json import warnings - -from ..tokenization import GufeTokenizable, JSON_HANDLER -from ..utils import ensure_filelike +from collections.abc import Iterable +from typing import Optional, Union from ..chemicalsystem import ChemicalSystem -from ..protocols import Protocol, ProtocolDAG, ProtocolResult, ProtocolDAGResult from ..mapping import ComponentMapping +from ..protocols import Protocol, ProtocolDAG, ProtocolDAGResult, ProtocolResult +from ..tokenization import JSON_HANDLER, GufeTokenizable +from ..utils import ensure_filelike class TransformationBase(GufeTokenizable): - """Transformation base class. - - """ def __init__( self, protocol: Protocol, name: Optional[str] = None, ): + """Transformation base class. + + Parameters + ---------- + protocol : Protocol + The sampling method to use for the transformation. + name : str, optional + A human-readable name for this transformation. + + """ self._protocol = protocol self._name = name @@ -32,11 +39,10 @@ def _defaults(cls): @property def name(self) -> Optional[str]: - """ - Optional identifier for the transformation; used as part of its hash. + """Optional identifier for the transformation; used as part of its hash. Set this to a unique value if adding multiple, otherwise identical - transformations to the same :class:`AlchemicalNetwork` to avoid + transformations to the same :class:`.AlchemicalNetwork` to avoid deduplication. """ return self._name @@ -45,6 +51,18 @@ def name(self) -> Optional[str]: def _from_dict(cls, d: dict): return cls(**d) + @property + @abc.abstractmethod + def stateA(self) -> ChemicalSystem: + """The starting :class:`.ChemicalSystem` for the transformation.""" + raise NotImplementedError + + @property + @abc.abstractmethod + def stateB(self) -> ChemicalSystem: + """The ending :class:`.ChemicalSystem` for the transformation.""" + raise NotImplementedError + @abc.abstractmethod def create( self, @@ -53,27 +71,25 @@ def create( name: Optional[str] = None, ) -> ProtocolDAG: """ - Returns a ``ProtocolDAG`` executing this ``Transformation.protocol``. + Returns a :class:`.ProtocolDAG` executing this ``Transformation.protocol``. """ raise NotImplementedError - def gather( - self, protocol_dag_results: Iterable[ProtocolDAGResult] - ) -> ProtocolResult: - """ - Gather multiple ``ProtocolDAGResult`` into a single ``ProtocolResult``. + def gather(self, protocol_dag_results: Iterable[ProtocolDAGResult]) -> ProtocolResult: + """Gather multiple :class:`.ProtocolDAGResult` \s into a single + :class:`.ProtocolResult`. Parameters ---------- protocol_dag_results : Iterable[ProtocolDAGResult] - The ``ProtocolDAGResult`` objects to assemble aggregate quantities - from. + The :class:`.ProtocolDAGResult` objects to assemble aggregate + quantities from. Returns ------- ProtocolResult - Aggregated results from many ``ProtocolDAGResult`` objects, all from - a given ``Protocol``. + Aggregated results from many :class:`.ProtocolDAGResult` objects, + all from a given :class:`.Protocol`. """ return self.protocol.gather(protocol_dag_results=protocol_dag_results) @@ -81,18 +97,17 @@ def gather( def dump(self, file): """Dump this Transformation to a JSON file. - Note that this is not space-efficient: for example, any - ``Component`` which is used in both ``ChemicalSystem`` objects will be - represented twice in the JSON output. + Note that this is not space-efficient: for example, any ``Component`` + which is used in both ``ChemicalSystem`` objects will be represented + twice in the JSON output. Parameters ---------- file : Union[PathLike, FileLike] - a pathlike of filelike to save this transformation to. + A pathlike of filelike to save this transformation to. """ - with ensure_filelike(file, mode='w') as f: - json.dump(self.to_dict(), f, cls=JSON_HANDLER.encoder, - sort_keys=True) + with ensure_filelike(file, mode="w") as f: + json.dump(self.to_dict(), f, cls=JSON_HANDLER.encoder, sort_keys=True) @classmethod def load(cls, file): @@ -101,9 +116,9 @@ def load(cls, file): Parameters ---------- file : Union[PathLike, FileLike] - a pathlike or filelike to read this transformation from + A pathlike or filelike to read this transformation from. """ - with ensure_filelike(file, mode='r') as f: + with ensure_filelike(file, mode="r") as f: dct = json.load(f, cls=JSON_HANDLER.decoder) return cls.from_dict(dct) @@ -124,25 +139,27 @@ def __init__( mapping: Optional[Union[ComponentMapping, list[ComponentMapping], dict[str, ComponentMapping]]] = None, name: Optional[str] = None, ): - """Two chemical states with a method for estimating free energy difference + """Two chemical states with a method for estimating the free energy + difference between them. - Connects two :class:`.ChemicalSystem` objects, with directionality, - and relates this to a :class:`.Protocol` which will provide an estimate of - the free energy difference of moving between these systems. - Used as an edge of an :class:`.AlchemicalNetwork`. + Connects two :class:`.ChemicalSystem` objects, with directionality, and + relates these to a :class:`.Protocol` which will provide an estimate of + the free energy difference between these systems. Used as an edge of an + :class:`.AlchemicalNetwork`. Parameters ---------- - stateA, stateB: ChemicalSystem - The start (A) and end (B) states of the transformation - protocol: Protocol - The method used to estimate the free energy difference between states - A and B + stateA, stateB : ChemicalSystem + The start (A) and end (B) states of the transformation. + protocol : Protocol + The method used to estimate the free energy difference between + states A and B. mapping : Optional[Union[ComponentMapping, list[ComponentMapping]]] - the details of any transformations between :class:`.Component` \s of - the two states + The details of any transformations between :class:`.Component` \s + of the two states. name : str, optional - a human-readable tag for this transformation + A human-readable name for this transformation. + """ if isinstance(mapping, dict): warnings.warn(("mapping input as a dict is deprecated, " @@ -157,9 +174,7 @@ def __init__( self._name = name def __repr__(self): - attrs = ['name', 'stateA', 'stateB', 'protocol', 'mapping'] - content = ", ".join([f"{i}={getattr(self, i)}" for i in attrs]) - return f"{self.__class__.__name__}({content})" + return f"{self.__class__.__name__}(stateA={self.stateA}, " f"stateB={self.stateB}, protocol={self.protocol})" @property def stateA(self) -> ChemicalSystem: @@ -173,7 +188,7 @@ def stateB(self) -> ChemicalSystem: @property def protocol(self) -> Protocol: - """The protocol used to perform the transformation. + """The :class:`.Protocol` used to perform the transformation. This protocol estimates the free energy differences between ``stateA`` and ``stateB`` :class:`.ChemicalSystem` objects. It includes all details @@ -216,17 +231,6 @@ def create( class NonTransformation(TransformationBase): - """A non-alchemical edge of an alchemical network. - - A "transformation" that performs no transformation at all. - Technically a self-loop, or an edge with the same ``ChemicalSystem`` at - either end. - - Functionally used for applying a dynamics protocol to a ``ChemicalSystem`` - that performs no alchemical transformation at all. This allows e.g. - equilibrium MD to be performed on a ``ChemicalSystem`` as desired alongside - alchemical protocols between it and and other ``ChemicalSystem`` objects. - """ def __init__( self, @@ -234,6 +238,27 @@ def __init__( protocol: Protocol, name: Optional[str] = None, ): + """A non-alchemical edge of an alchemical network. + + A "transformation" that performs no transformation at all. + Technically a self-loop, or an edge with the same ``ChemicalSystem`` at + either end. + + Functionally used for applying a dynamics protocol to a ``ChemicalSystem`` + that performs no alchemical transformation at all. This allows e.g. + equilibrium MD to be performed on a ``ChemicalSystem`` as desired alongside + alchemical protocols between it and and other ``ChemicalSystem`` objects. + + Parameters + ---------- + system : ChemicalSystem + The (identical) end states of the "transformation" to be sampled + protocol : Protocol + The sampling method to use on the ``system`` + name : str, optional + A human-readable name for this transformation. + + """ self._system = system self._protocol = protocol @@ -245,21 +270,31 @@ def __repr__(self): return f"{self.__class__.__name__}({content})" @property - def stateA(self): + def stateA(self) -> ChemicalSystem: + """The :class:`.ChemicalSystem` this "transformation" samples. + + Synonomous with ``system`` attribute. + + """ return self._system @property - def stateB(self): + def stateB(self) -> ChemicalSystem: + """The :class:`.ChemicalSystem` this "transformation" samples. + + Synonomous with ``system`` attribute. + + """ return self._system @property def system(self) -> ChemicalSystem: + """The :class:`.ChemicalSystem` this "transformation" samples.""" return self._system @property def protocol(self): - """ - The protocol for sampling dynamics of the `ChemicalSystem`. + """The :class:`.Protocol` for sampling dynamics of the ``system``. Includes all details needed to perform required simulations/calculations. @@ -282,6 +317,9 @@ def create( """ Returns a ``ProtocolDAG`` executing this ``NonTransformation.protocol``. """ + # TODO: once we have an implicit component mapping concept, use this + # here instead of None to allow use of alchemical protocols with + # NonTransformations return self.protocol.create( stateA=self.system, stateB=self.system, diff --git a/gufe/utils.py b/gufe/utils.py index f9d3b0ff..ea9ad861 100644 --- a/gufe/utils.py +++ b/gufe/utils.py @@ -22,13 +22,13 @@ class ensure_filelike: the stream will always be closed. Filelike inputs will close if this parameter is True. """ + def __init__(self, fn, mode=None, force_close=False): filelikes = (io.TextIOBase, io.RawIOBase, io.BufferedIOBase) if isinstance(fn, filelikes): if mode is not None: warnings.warn( - f"mode='{mode}' specified with {fn.__class__.__name__}." - " User-specified mode will be ignored." + f"mode='{mode}' specified with {fn.__class__.__name__}." " User-specified mode will be ignored." ) self.to_open = None self.do_close = force_close @@ -51,4 +51,3 @@ def __enter__(self): def __exit__(self, type, value, traceback): if self.do_close: self.context.close() - diff --git a/gufe/vendor/pdb_file/PdbxContainers.py b/gufe/vendor/pdb_file/PdbxContainers.py index adedadc2..e8bcd94d 100644 --- a/gufe/vendor/pdb_file/PdbxContainers.py +++ b/gufe/vendor/pdb_file/PdbxContainers.py @@ -5,7 +5,7 @@ # # Update: # 23-Mar-2011 jdw Added method to rename attributes in category containers. -# 05-Apr-2011 jdw Change cif writer to select double quoting as preferred +# 05-Apr-2011 jdw Change cif writer to select double quoting as preferred # quoting style where possible. # 16-Jan-2012 jdw Create base class for DataCategory class # 22-Mar-2012 jdw when append attributes to existing categories update @@ -36,29 +36,31 @@ data and definition meta data. """ -from __future__ import absolute_import __docformat__ = "restructuredtext en" -__author__ = "John Westbrook" -__email__ = "jwest@rcsb.rutgers.edu" -__license__ = "Creative Commons Attribution 3.0 Unported" -__version__ = "V0.01" +__author__ = "John Westbrook" +__email__ = "jwest@rcsb.rutgers.edu" +__license__ = "Creative Commons Attribution 3.0 Unported" +__version__ = "V0.01" -import re,sys,traceback +import re +import sys +import traceback + + +class CifName: + """Class of utilities for CIF-style data names -""" -class CifName(object): - ''' Class of utilities for CIF-style data names - - ''' def __init__(self): pass @staticmethod def categoryPart(name): - tname="" + tname = "" if name.startswith("_"): - tname=name[1:] + tname = name[1:] else: - tname=name + tname = name i = tname.find(".") if i == -1: @@ -72,40 +74,40 @@ def attributePart(name): if i == -1: return None else: - return name[i+1:] + return name[i + 1 :] -class ContainerBase(object): - ''' Container base class for data and definition objects. - ''' - def __init__(self,name): +class ContainerBase: + """Container base class for data and definition objects.""" + + def __init__(self, name): # The enclosing scope of the data container (e.g. data_/save_) self.__name = name - # List of category names within this container - - self.__objNameList=[] + # List of category names within this container - + self.__objNameList = [] # dictionary of DataCategory objects keyed by category name. - self.__objCatalog={} - self.__type=None + self.__objCatalog = {} + self.__type = None def getType(self): return self.__type - def setType(self,type): - self.__type=type - + def setType(self, type): + self.__type = type + def getName(self): return self.__name - def setName(self,name): - self.__name=name + def setName(self, name): + self.__name = name - def exists(self,name): + def exists(self, name): if name in self.__objCatalog: return True else: return False - - def getObj(self,name): + + def getObj(self, name): if name in self.__objCatalog: return self.__objCatalog[name] else: @@ -113,55 +115,51 @@ def getObj(self,name): def getObjNameList(self): return self.__objNameList - - def append(self,obj): - """ Add the input object to the current object catalog. An existing object - of the same name will be overwritten. + + def append(self, obj): + """Add the input object to the current object catalog. An existing object + of the same name will be overwritten. """ if obj.getName() is not None: if obj.getName() not in self.__objCatalog: - # self.__objNameList is keeping track of object order here -- + # self.__objNameList is keeping track of object order here -- self.__objNameList.append(obj.getName()) - self.__objCatalog[obj.getName()]=obj + self.__objCatalog[obj.getName()] = obj - def replace(self,obj): - """ Replace an existing object with the input object - """ - if ((obj.getName() is not None) and (obj.getName() in self.__objCatalog) ): - self.__objCatalog[obj.getName()]=obj - + def replace(self, obj): + """Replace an existing object with the input object""" + if (obj.getName() is not None) and (obj.getName() in self.__objCatalog): + self.__objCatalog[obj.getName()] = obj - def printIt(self,fh=sys.stdout,type="brief"): - fh.write("+ %s container: %30s contains %4d categories\n" % - (self.getType(),self.getName(),len(self.__objNameList))) + def printIt(self, fh=sys.stdout, type="brief"): + fh.write( + "+ %s container: %30s contains %4d categories\n" % (self.getType(), self.getName(), len(self.__objNameList)) + ) for nm in self.__objNameList: fh.write("--------------------------------------------\n") fh.write("Data category: %s\n" % nm) - if type == 'brief': + if type == "brief": self.__objCatalog[nm].printIt(fh) else: self.__objCatalog[nm].dumpIt(fh) - - def rename(self,curName,newName): - """ Change the name of an object in place - - """ + def rename(self, curName, newName): + """Change the name of an object in place -""" try: - i=self.__objNameList.index(curName) - self.__objNameList[i]=newName - self.__objCatalog[newName]=self.__objCatalog[curName] + i = self.__objNameList.index(curName) + self.__objNameList[i] = newName + self.__objCatalog[newName] = self.__objCatalog[curName] self.__objCatalog[newName].setName(newName) return True except: return False - def remove(self,curName): - """ Revmove object by name. Return True on success or False otherwise. - """ + def remove(self, curName): + """Revmove object by name. Return True on success or False otherwise.""" try: if curName in self.__objCatalog: del self.__objCatalog[curName] - i=self.__objNameList.index(curName) + i = self.__objNameList.index(curName) del self.__objNameList[i] return True else: @@ -171,171 +169,180 @@ def remove(self,curName): return False - + class DefinitionContainer(ContainerBase): - def __init__(self,name): - super(DefinitionContainer,self).__init__(name) - self.setType('definition') + def __init__(self, name): + super().__init__(name) + self.setType("definition") def isCategory(self): - if self.exists('category'): + if self.exists("category"): return True return False def isAttribute(self): - if self.exists('item'): + if self.exists("item"): return True return False - - def printIt(self,fh=sys.stdout,type="brief"): - fh.write("Definition container: %30s contains %4d categories\n" % - (self.getName(),len(self.getObjNameList()))) + def printIt(self, fh=sys.stdout, type="brief"): + fh.write("Definition container: %30s contains %4d categories\n" % (self.getName(), len(self.getObjNameList()))) if self.isCategory(): fh.write("Definition type: category\n") elif self.isAttribute(): fh.write("Definition type: item\n") else: - fh.write("Definition type: undefined\n") + fh.write("Definition type: undefined\n") for nm in self.getObjNameList(): fh.write("--------------------------------------------\n") fh.write("Definition category: %s\n" % nm) - if type == 'brief': + if type == "brief": self.getObj(nm).printIt(fh) else: self.getObj(nm).dumpId(fh) - class DataContainer(ContainerBase): - ''' Container class for DataCategory objects. - ''' - def __init__(self,name): - super(DataContainer,self).__init__(name) - self.setType('data') - self.__globalFlag=False - - def invokeDataBlockMethod(self,type,method,db): + """Container class for DataCategory objects.""" + + def __init__(self, name): + super().__init__(name) + self.setType("data") + self.__globalFlag = False + + def invokeDataBlockMethod(self, type, method, db): self.__currentRow = 1 exec(method.getInline()) def setGlobal(self): - self.__globalFlag=True + self.__globalFlag = True def getGlobal(self): - return self.__globalFlag + return self.__globalFlag + - +class DataCategoryBase: + """Base object definition for a data category -""" -class DataCategoryBase(object): - """ Base object definition for a data category - - """ - def __init__(self,name,attributeNameList=None,rowList=None): + def __init__(self, name, attributeNameList=None, rowList=None): self._name = name # if rowList is not None: - self._rowList=rowList + self._rowList = rowList else: - self._rowList=[] + self._rowList = [] if attributeNameList is not None: - self._attributeNameList=attributeNameList + self._attributeNameList = attributeNameList else: - self._attributeNameList=[] + self._attributeNameList = [] # # Derived class data - # - self._catalog={} - self._numAttributes=0 + self._catalog = {} + self._numAttributes = 0 # self.__setup() def __setup(self): self._numAttributes = len(self._attributeNameList) - self._catalog={} + self._catalog = {} for attributeName in self._attributeNameList: - attributeNameLC = attributeName.lower() + attributeNameLC = attributeName.lower() self._catalog[attributeNameLC] = attributeName + # - def setRowList(self,rowList): - self._rowList=rowList + def setRowList(self, rowList): + self._rowList = rowList - def setAttributeNameList(self,attributeNameList): - self._attributeNameList=attributeNameList + def setAttributeNameList(self, attributeNameList): + self._attributeNameList = attributeNameList self.__setup() - def setName(self,name): - self._name=name + def setName(self, name): + self._name = name def get(self): - return (self._name,self._attributeNameList,self._rowList) + return (self._name, self._attributeNameList, self._rowList) + - - class DataCategory(DataCategoryBase): - """ Methods for creating, accessing, and formatting PDBx cif data categories. - """ - def __init__(self,name,attributeNameList=None,rowList=None): - super(DataCategory,self).__init__(name,attributeNameList,rowList) + """Methods for creating, accessing, and formatting PDBx cif data categories.""" + + def __init__(self, name, attributeNameList=None, rowList=None): + super().__init__(name, attributeNameList, rowList) # self.__lfh = sys.stdout - - self.__currentRowIndex=0 - self.__currentAttribute=None + + self.__currentRowIndex = 0 + self.__currentAttribute = None # - self.__avoidEmbeddedQuoting=False + self.__avoidEmbeddedQuoting = False # # -------------------------------------------------------------------- - # any whitespace - self.__wsRe=re.compile(r"\s") - self.__wsAndQuotesRe=re.compile(r"[\s'\"]") + # any whitespace + self.__wsRe = re.compile(r"\s") + self.__wsAndQuotesRe = re.compile(r"[\s'\"]") # any newline or carriage control - self.__nlRe=re.compile(r"[\n\r]") + self.__nlRe = re.compile(r"[\n\r]") # - # single quote - self.__sqRe=re.compile(r"[']") + # single quote + self.__sqRe = re.compile(r"[']") # - self.__sqWsRe=re.compile(r"('\s)|(\s')") - - # double quote - self.__dqRe=re.compile(r'["]') - self.__dqWsRe=re.compile(r'("\s)|(\s")') + self.__sqWsRe = re.compile(r"('\s)|(\s')") + + # double quote + self.__dqRe = re.compile(r'["]') + self.__dqWsRe = re.compile(r'("\s)|(\s")') # - self.__intRe=re.compile(r'^[0-9]+$') - self.__floatRe=re.compile(r'^-?(([0-9]+)[.]?|([0-9]*[.][0-9]+))([(][0-9]+[)])?([eE][+-]?[0-9]+)?$') + self.__intRe = re.compile(r"^[0-9]+$") + self.__floatRe = re.compile(r"^-?(([0-9]+)[.]?|([0-9]*[.][0-9]+))([(][0-9]+[)])?([eE][+-]?[0-9]+)?$") # - self.__dataTypeList=['DT_NULL_VALUE','DT_INTEGER','DT_FLOAT','DT_UNQUOTED_STRING','DT_ITEM_NAME', - 'DT_DOUBLE_QUOTED_STRING','DT_SINGLE_QUOTED_STRING','DT_MULTI_LINE_STRING'] - self.__formatTypeList=['FT_NULL_VALUE','FT_NUMBER','FT_NUMBER','FT_UNQUOTED_STRING', - 'FT_QUOTED_STRING','FT_QUOTED_STRING','FT_QUOTED_STRING','FT_MULTI_LINE_STRING'] + self.__dataTypeList = [ + "DT_NULL_VALUE", + "DT_INTEGER", + "DT_FLOAT", + "DT_UNQUOTED_STRING", + "DT_ITEM_NAME", + "DT_DOUBLE_QUOTED_STRING", + "DT_SINGLE_QUOTED_STRING", + "DT_MULTI_LINE_STRING", + ] + self.__formatTypeList = [ + "FT_NULL_VALUE", + "FT_NUMBER", + "FT_NUMBER", + "FT_UNQUOTED_STRING", + "FT_QUOTED_STRING", + "FT_QUOTED_STRING", + "FT_QUOTED_STRING", + "FT_MULTI_LINE_STRING", + ] # - def __getitem__(self, x): - """ Implements list-type functionality - - Implements op[x] for some special cases - - x=integer - returns the row in category (normal list behavior) - x=string - returns the value of attribute 'x' in first row. + """Implements list-type functionality - + Implements op[x] for some special cases - + x=integer - returns the row in category (normal list behavior) + x=string - returns the value of attribute 'x' in first row. """ if isinstance(x, int): - #return self._rowList.__getitem__(x) + # return self._rowList.__getitem__(x) return self._rowList[x] elif isinstance(x, str): try: - #return self._rowList[0][x] - ii=self.getAttributeIndex(x) + # return self._rowList[0][x] + ii = self.getAttributeIndex(x) return self._rowList[0][ii] except (IndexError, KeyError): raise KeyError raise TypeError(x) - - + def getCurrentAttribute(self): return self.__currentAttribute - def getRowIndex(self): return self.__currentRowIndex @@ -343,20 +350,20 @@ def getRowList(self): return self._rowList def getRowCount(self): - return (len(self._rowList)) + return len(self._rowList) - def getRow(self,index): + def getRow(self, index): try: return self._rowList[index] except: return [] - def removeRow(self,index): + def removeRow(self, index): try: - if ((index >= 0) and (index < len(self._rowList))): - del self._rowList[index] + if (index >= 0) and (index < len(self._rowList)): + del self._rowList[index] if self.__currentRowIndex >= len(self._rowList): - self.__currentRowIndex = len(self._rowList) -1 + self.__currentRowIndex = len(self._rowList) - 1 return True else: pass @@ -365,20 +372,19 @@ def removeRow(self,index): return False - def getFullRow(self,index): - """ Return a full row based on the length of the the attribute list. - """ + def getFullRow(self, index): + """Return a full row based on the length of the the attribute list.""" try: - if (len(self._rowList[index]) < self._numAttributes): - for ii in range( self._numAttributes-len(self._rowList[index])): - self._rowList[index].append('?') + if len(self._rowList[index]) < self._numAttributes: + for ii in range(self._numAttributes - len(self._rowList[index])): + self._rowList[index].append("?") return self._rowList[index] except: - return ['?' for ii in range(self._numAttributes)] + return ["?" for ii in range(self._numAttributes)] def getName(self): return self._name - + def getAttributeList(self): return self._attributeNameList @@ -386,54 +392,53 @@ def getAttributeCount(self): return len(self._attributeNameList) def getAttributeListWithOrder(self): - oL=[] - for ii,att in enumerate(self._attributeNameList): - oL.append((att,ii)) + oL = [] + for ii, att in enumerate(self._attributeNameList): + oL.append((att, ii)) return oL - def getAttributeIndex(self,attributeName): + def getAttributeIndex(self, attributeName): try: return self._attributeNameList.index(attributeName) except: return -1 - def hasAttribute(self,attributeName): + def hasAttribute(self, attributeName): return attributeName in self._attributeNameList - - def getIndex(self,attributeName): + + def getIndex(self, attributeName): try: return self._attributeNameList.index(attributeName) except: return -1 def getItemNameList(self): - itemNameList=[] + itemNameList = [] for att in self._attributeNameList: - itemNameList.append("_"+self._name+"."+att) + itemNameList.append("_" + self._name + "." + att) return itemNameList - - def append(self,row): - #self.__lfh.write("PdbxContainer(append) category %s row %r\n" % (self._name,row)) + + def append(self, row): + # self.__lfh.write("PdbxContainer(append) category %s row %r\n" % (self._name,row)) self._rowList.append(row) - def appendAttribute(self,attributeName): - attributeNameLC = attributeName.lower() - if attributeNameLC in self._catalog: + def appendAttribute(self, attributeName): + attributeNameLC = attributeName.lower() + if attributeNameLC in self._catalog: i = self._attributeNameList.index(self._catalog[attributeNameLC]) self._attributeNameList[i] = attributeName self._catalog[attributeNameLC] = attributeName - #self.__lfh.write("Appending existing attribute %s\n" % attributeName) + # self.__lfh.write("Appending existing attribute %s\n" % attributeName) else: - #self.__lfh.write("Appending existing attribute %s\n" % attributeName) + # self.__lfh.write("Appending existing attribute %s\n" % attributeName) self._attributeNameList.append(attributeName) self._catalog[attributeNameLC] = attributeName # self._numAttributes = len(self._attributeNameList) - - def appendAttributeExtendRows(self,attributeName): - attributeNameLC = attributeName.lower() - if attributeNameLC in self._catalog: + def appendAttributeExtendRows(self, attributeName): + attributeNameLC = attributeName.lower() + if attributeNameLC in self._catalog: i = self._attributeNameList.index(self._catalog[attributeNameLC]) self._attributeNameList[i] = attributeName self._catalog[attributeNameLC] = attributeName @@ -442,15 +447,13 @@ def appendAttributeExtendRows(self,attributeName): self._attributeNameList.append(attributeName) self._catalog[attributeNameLC] = attributeName # add a placeholder to any existing rows for the new attribute. - if (len(self._rowList) > 0): + if len(self._rowList) > 0: for row in self._rowList: row.append("?") # self._numAttributes = len(self._attributeNameList) - - - def getValue(self,attributeName=None,rowIndex=None): + def getValue(self, attributeName=None, rowIndex=None): if attributeName is None: attribute = self.__currentAttribute else: @@ -458,363 +461,377 @@ def getValue(self,attributeName=None,rowIndex=None): if rowIndex is None: rowI = self.__currentRowIndex else: - rowI =rowIndex - - if isinstance(attribute, str) and isinstance(rowI,int): + rowI = rowIndex + + if isinstance(attribute, str) and isinstance(rowI, int): try: return self._rowList[rowI][self._attributeNameList.index(attribute)] - except (IndexError): - raise IndexError + except IndexError: + raise IndexError raise IndexError(attribute) - def setValue(self,value,attributeName=None,rowIndex=None): + def setValue(self, value, attributeName=None, rowIndex=None): if attributeName is None: - attribute=self.__currentAttribute + attribute = self.__currentAttribute else: - attribute=attributeName + attribute = attributeName if rowIndex is None: rowI = self.__currentRowIndex else: rowI = rowIndex - if isinstance(attribute, str) and isinstance(rowI,int): + if isinstance(attribute, str) and isinstance(rowI, int): try: # if row index is out of range - add the rows - - for ii in range(rowI+1 - len(self._rowList)): - self._rowList.append(self.__emptyRow()) + for ii in range(rowI + 1 - len(self._rowList)): + self._rowList.append(self.__emptyRow()) # self._rowList[rowI][attribute]=value - ll=len(self._rowList[rowI]) - ind=self._attributeNameList.index(attribute) - - # extend the list if needed - - if ( ind >= ll): - self._rowList[rowI].extend([None for ii in xrange(2*ind -ll)]) - self._rowList[rowI][ind]=value - except (IndexError): - self.__lfh.write("DataCategory(setvalue) index error category %s attribute %s index %d value %r\n" % - (self._name,attribute,rowI,value)) - traceback.print_exc(file=self.__lfh) - #raise IndexError - except (ValueError): - self.__lfh.write("DataCategory(setvalue) value error category %s attribute %s index %d value %r\n" % - (self._name,attribute,rowI,value)) - traceback.print_exc(file=self.__lfh) - #raise ValueError + ll = len(self._rowList[rowI]) + ind = self._attributeNameList.index(attribute) + + # extend the list if needed - + if ind >= ll: + self._rowList[rowI].extend([None for ii in xrange(2 * ind - ll)]) + self._rowList[rowI][ind] = value + except IndexError: + self.__lfh.write( + "DataCategory(setvalue) index error category %s attribute %s index %d value %r\n" + % (self._name, attribute, rowI, value) + ) + traceback.print_exc(file=self.__lfh) + # raise IndexError + except ValueError: + self.__lfh.write( + "DataCategory(setvalue) value error category %s attribute %s index %d value %r\n" + % (self._name, attribute, rowI, value) + ) + traceback.print_exc(file=self.__lfh) + # raise ValueError def __emptyRow(self): return [None for ii in range(len(self._attributeNameList))] - - def replaceValue(self,oldValue,newValue,attributeName): - numReplace=0 + + def replaceValue(self, oldValue, newValue, attributeName): + numReplace = 0 if attributeName not in self._attributeNameList: return numReplace - ind=self._attributeNameList.index(attributeName) + ind = self._attributeNameList.index(attributeName) for row in self._rowList: if row[ind] == oldValue: - row[ind]=newValue + row[ind] = newValue numReplace += 1 return numReplace - def replaceSubstring(self,oldValue,newValue,attributeName): - ok=False - if attributeName not in self._attributeNameList: + def replaceSubstring(self, oldValue, newValue, attributeName): + ok = False + if attributeName not in self._attributeNameList: return ok - ind=self._attributeNameList.index(attributeName) + ind = self._attributeNameList.index(attributeName) for row in self._rowList: - val=row[ind] - row[ind]=val.replace(oldValue,newValue) + val = row[ind] + row[ind] = val.replace(oldValue, newValue) if val != row[ind]: - ok=True + ok = True return ok - - def invokeAttributeMethod(self,attributeName,type,method,db): + + def invokeAttributeMethod(self, attributeName, type, method, db): self.__currentRowIndex = 0 - self.__currentAttribute=attributeName + self.__currentAttribute = attributeName self.appendAttribute(attributeName) - currentRowIndex=self.__currentRowIndex + currentRowIndex = self.__currentRowIndex # - ind=self._attributeNameList.index(attributeName) + ind = self._attributeNameList.index(attributeName) if len(self._rowList) == 0: - row=[None for ii in xrange(len(self._attributeNameList)*2)] - row[ind]=None + row = [None for ii in xrange(len(self._attributeNameList) * 2)] + row[ind] = None self._rowList.append(row) - + for row in self._rowList: ll = len(row) - if (ind >= ll): - row.extend([None for ii in xrange(2*ind-ll)]) - row[ind]=None + if ind >= ll: + row.extend([None for ii in xrange(2 * ind - ll)]) + row[ind] = None exec(method.getInline()) - self.__currentRowIndex+=1 - currentRowIndex=self.__currentRowIndex + self.__currentRowIndex += 1 + currentRowIndex = self.__currentRowIndex - def invokeCategoryMethod(self,type,method,db): + def invokeCategoryMethod(self, type, method, db): self.__currentRowIndex = 0 exec(method.getInline()) def getAttributeLengthMaximumList(self): - mList=[0 for i in len(self._attributeNameList)] + mList = [0 for i in len(self._attributeNameList)] for row in self._rowList: - for indx,val in enumerate(row): - mList[indx] = max(mList[indx],len(val)) + for indx, val in enumerate(row): + mList[indx] = max(mList[indx], len(val)) return mList - - def renameAttribute(self,curAttributeName,newAttributeName): - """ Change the name of an attribute in place - - """ + + def renameAttribute(self, curAttributeName, newAttributeName): + """Change the name of an attribute in place -""" try: - i=self._attributeNameList.index(curAttributeName) - self._attributeNameList[i]=newAttributeName - del self._catalog[curAttributeName.lower()] - self._catalog[newAttributeName.lower()]=newAttributeName + i = self._attributeNameList.index(curAttributeName) + self._attributeNameList[i] = newAttributeName + del self._catalog[curAttributeName.lower()] + self._catalog[newAttributeName.lower()] = newAttributeName return True except: return False - - def printIt(self,fh=sys.stdout): + + def printIt(self, fh=sys.stdout): fh.write("--------------------------------------------\n") - fh.write(" Category: %s attribute list length: %d\n" % - (self._name,len(self._attributeNameList))) + fh.write(" Category: %s attribute list length: %d\n" % (self._name, len(self._attributeNameList))) for at in self._attributeNameList: - fh.write(" Category: %s attribute: %s\n" % (self._name,at)) - + fh.write(f" Category: {self._name} attribute: {at}\n") + fh.write(" Row value list length: %d\n" % len(self._rowList)) # for row in self._rowList[:2]: # if len(row) == len(self._attributeNameList): - for ii,v in enumerate(row): - fh.write(" %30s: %s ...\n" % (self._attributeNameList[ii],str(v)[:30])) + for ii, v in enumerate(row): + fh.write(" %30s: %s ...\n" % (self._attributeNameList[ii], str(v)[:30])) else: - fh.write("+WARNING - %s data length %d attribute name length %s mismatched\n" % - (self._name,len(row),len(self._attributeNameList))) + fh.write( + "+WARNING - %s data length %d attribute name length %s mismatched\n" + % (self._name, len(row), len(self._attributeNameList)) + ) - def dumpIt(self,fh=sys.stdout): + def dumpIt(self, fh=sys.stdout): fh.write("--------------------------------------------\n") - fh.write(" Category: %s attribute list length: %d\n" % - (self._name,len(self._attributeNameList))) + fh.write(" Category: %s attribute list length: %d\n" % (self._name, len(self._attributeNameList))) for at in self._attributeNameList: - fh.write(" Category: %s attribute: %s\n" % (self._name,at)) - + fh.write(f" Category: {self._name} attribute: {at}\n") + fh.write(" Value list length: %d\n" % len(self._rowList)) for row in self._rowList: - for ii,v in enumerate(row): - fh.write(" %30s: %s\n" % (self._attributeNameList[ii],v)) - + for ii, v in enumerate(row): + fh.write(" %30s: %s\n" % (self._attributeNameList[ii], v)) def __formatPdbx(self, inp): - """ Format input data following PDBx quoting rules - - """ + """Format input data following PDBx quoting rules -""" try: - if (inp is None): - return ("?",'DT_NULL_VALUE') + if inp is None: + return ("?", "DT_NULL_VALUE") # pure numerical values are returned as unquoted strings - if (isinstance(inp,int) or self.__intRe.search(str(inp))): - return ( [str(inp)],'DT_INTEGER') + if isinstance(inp, int) or self.__intRe.search(str(inp)): + return ([str(inp)], "DT_INTEGER") - if (isinstance(inp,float) or self.__floatRe.search(str(inp))): - return ([str(inp)],'DT_FLOAT') + if isinstance(inp, float) or self.__floatRe.search(str(inp)): + return ([str(inp)], "DT_FLOAT") # null value handling - - if (inp == "." or inp == "?"): - return ([inp],'DT_NULL_VALUE') + if inp == "." or inp == "?": + return ([inp], "DT_NULL_VALUE") - if (inp == ""): - return (["."],'DT_NULL_VALUE') + if inp == "": + return (["."], "DT_NULL_VALUE") # Contains white space or quotes ? if not self.__wsAndQuotesRe.search(inp): if inp.startswith("_"): - return (self.__doubleQuotedList(inp),'DT_ITEM_NAME') + return (self.__doubleQuotedList(inp), "DT_ITEM_NAME") else: - return ([str(inp)],'DT_UNQUOTED_STRING') + return ([str(inp)], "DT_UNQUOTED_STRING") else: if self.__nlRe.search(inp): - return (self.__semiColonQuotedList(inp),'DT_MULTI_LINE_STRING') + return (self.__semiColonQuotedList(inp), "DT_MULTI_LINE_STRING") else: - if (self.__avoidEmbeddedQuoting): + if self.__avoidEmbeddedQuoting: # change priority to choose double quoting where possible. if not self.__dqRe.search(inp) and not self.__sqWsRe.search(inp): - return (self.__doubleQuotedList(inp),'DT_DOUBLE_QUOTED_STRING') + return ( + self.__doubleQuotedList(inp), + "DT_DOUBLE_QUOTED_STRING", + ) elif not self.__sqRe.search(inp) and not self.__dqWsRe.search(inp): - return (self.__singleQuotedList(inp),'DT_SINGLE_QUOTED_STRING') + return ( + self.__singleQuotedList(inp), + "DT_SINGLE_QUOTED_STRING", + ) else: - return (self.__semiColonQuotedList(inp),'DT_MULTI_LINE_STRING') + return ( + self.__semiColonQuotedList(inp), + "DT_MULTI_LINE_STRING", + ) else: # change priority to choose double quoting where possible. if not self.__dqRe.search(inp): - return (self.__doubleQuotedList(inp),'DT_DOUBLE_QUOTED_STRING') + return ( + self.__doubleQuotedList(inp), + "DT_DOUBLE_QUOTED_STRING", + ) elif not self.__sqRe.search(inp): - return (self.__singleQuotedList(inp),'DT_SINGLE_QUOTED_STRING') + return ( + self.__singleQuotedList(inp), + "DT_SINGLE_QUOTED_STRING", + ) else: - return (self.__semiColonQuotedList(inp),'DT_MULTI_LINE_STRING') - - + return ( + self.__semiColonQuotedList(inp), + "DT_MULTI_LINE_STRING", + ) + except: - traceback.print_exc(file=self.__lfh) + traceback.print_exc(file=self.__lfh) def __dataTypePdbx(self, inp): - """ Detect the PDBx data type - - """ - if (inp is None): - return ('DT_NULL_VALUE') - + """Detect the PDBx data type -""" + if inp is None: + return "DT_NULL_VALUE" + # pure numerical values are returned as unquoted strings - if isinstance(inp,int) or self.__intRe.search(str(inp)): - return ('DT_INTEGER') + if isinstance(inp, int) or self.__intRe.search(str(inp)): + return "DT_INTEGER" - if isinstance(inp,float) or self.__floatRe.search(str(inp)): - return ('DT_FLOAT') + if isinstance(inp, float) or self.__floatRe.search(str(inp)): + return "DT_FLOAT" # null value handling - - if (inp == "." or inp == "?"): - return ('DT_NULL_VALUE') + if inp == "." or inp == "?": + return "DT_NULL_VALUE" - if (inp == ""): - return ('DT_NULL_VALUE') + if inp == "": + return "DT_NULL_VALUE" # Contains white space or quotes ? if not self.__wsAndQuotesRe.search(inp): if inp.startswith("_"): - return ('DT_ITEM_NAME') + return "DT_ITEM_NAME" else: - return ('DT_UNQUOTED_STRING') + return "DT_UNQUOTED_STRING" else: if self.__nlRe.search(inp): - return ('DT_MULTI_LINE_STRING') + return "DT_MULTI_LINE_STRING" else: - if (self.__avoidEmbeddedQuoting): + if self.__avoidEmbeddedQuoting: if not self.__sqRe.search(inp) and not self.__dqWsRe.search(inp): - return ('DT_DOUBLE_QUOTED_STRING') + return "DT_DOUBLE_QUOTED_STRING" elif not self.__dqRe.search(inp) and not self.__sqWsRe.search(inp): - return ('DT_SINGLE_QUOTED_STRING') + return "DT_SINGLE_QUOTED_STRING" else: - return ('DT_MULTI_LINE_STRING') + return "DT_MULTI_LINE_STRING" else: if not self.__sqRe.search(inp): - return ('DT_DOUBLE_QUOTED_STRING') + return "DT_DOUBLE_QUOTED_STRING" elif not self.__dqRe.search(inp): - return ('DT_SINGLE_QUOTED_STRING') + return "DT_SINGLE_QUOTED_STRING" else: - return ('DT_MULTI_LINE_STRING') + return "DT_MULTI_LINE_STRING" - def __singleQuotedList(self,inp): - l=[] + def __singleQuotedList(self, inp): + l = [] l.append("'") l.append(inp) - l.append("'") - return(l) + l.append("'") + return l - def __doubleQuotedList(self,inp): - l=[] + def __doubleQuotedList(self, inp): + l = [] l.append('"') l.append(inp) - l.append('"') - return(l) - - def __semiColonQuotedList(self,inp): - l=[] - l.append("\n") - if inp[-1] == '\n': + l.append('"') + return l + + def __semiColonQuotedList(self, inp): + l = [] + l.append("\n") + if inp[-1] == "\n": l.append(";") l.append(inp) l.append(";") - l.append("\n") + l.append("\n") else: l.append(";") l.append(inp) - l.append("\n") + l.append("\n") l.append(";") - l.append("\n") + l.append("\n") - return(l) + return l - def getValueFormatted(self,attributeName=None,rowIndex=None): + def getValueFormatted(self, attributeName=None, rowIndex=None): if attributeName is None: - attribute=self.__currentAttribute + attribute = self.__currentAttribute else: - attribute=attributeName + attribute = attributeName if rowIndex is None: rowI = self.__currentRowIndex else: rowI = rowIndex - - if isinstance(attribute, str) and isinstance(rowI,int): + + if isinstance(attribute, str) and isinstance(rowI, int): try: - list,type=self.__formatPdbx(self._rowList[rowI][self._attributeNameList.index(attribute)]) + list, type = self.__formatPdbx(self._rowList[rowI][self._attributeNameList.index(attribute)]) return "".join(list) - except (IndexError): - self.__lfh.write("attributeName %s rowI %r rowdata %r\n" % (attributeName,rowI,self._rowList[rowI])) - raise IndexError + except IndexError: + self.__lfh.write(f"attributeName {attributeName} rowI {rowI!r} rowdata {self._rowList[rowI]!r}\n") + raise IndexError raise TypeError(attribute) - - def getValueFormattedByIndex(self,attributeIndex,rowIndex): + def getValueFormattedByIndex(self, attributeIndex, rowIndex): try: - list,type=self.__formatPdbx(self._rowList[rowIndex][attributeIndex]) + list, type = self.__formatPdbx(self._rowList[rowIndex][attributeIndex]) return "".join(list) - except (IndexError): - raise IndexError + except IndexError: + raise IndexError - def getAttributeValueMaxLengthList(self,steps=1): - mList=[0 for i in range(len(self._attributeNameList))] + def getAttributeValueMaxLengthList(self, steps=1): + mList = [0 for i in range(len(self._attributeNameList))] for row in self._rowList[::steps]: for indx in range(len(self._attributeNameList)): - val=row[indx] - mList[indx] = max(mList[indx],len(str(val))) + val = row[indx] + mList[indx] = max(mList[indx], len(str(val))) return mList - def getFormatTypeList(self,steps=1): + def getFormatTypeList(self, steps=1): try: - curDataTypeList=['DT_NULL_VALUE' for i in range(len(self._attributeNameList))] + curDataTypeList = ["DT_NULL_VALUE" for i in range(len(self._attributeNameList))] for row in self._rowList[::steps]: for indx in range(len(self._attributeNameList)): - val=row[indx] + val = row[indx] # print "index ",indx," val ",val - dType=self.__dataTypePdbx(val) - dIndx=self.__dataTypeList.index(dType) + dType = self.__dataTypePdbx(val) + dIndx = self.__dataTypeList.index(dType) # print "d type", dType, " d type index ",dIndx - - cType=curDataTypeList[indx] - cIndx=self.__dataTypeList.index(cType) - cIndx= max(cIndx,dIndx) - curDataTypeList[indx]=self.__dataTypeList[cIndx] + + cType = curDataTypeList[indx] + cIndx = self.__dataTypeList.index(cType) + cIndx = max(cIndx, dIndx) + curDataTypeList[indx] = self.__dataTypeList[cIndx] # Map the format types to the data types - curFormatTypeList=[] + curFormatTypeList = [] for dt in curDataTypeList: - ii=self.__dataTypeList.index(dt) + ii = self.__dataTypeList.index(dt) curFormatTypeList.append(self.__formatTypeList[ii]) except: - self.__lfh.write("PdbxDataCategory(getFormatTypeList) ++Index error at index %d in row %r\n" % (indx,row)) + self.__lfh.write("PdbxDataCategory(getFormatTypeList) ++Index error at index %d in row %r\n" % (indx, row)) - return curFormatTypeList,curDataTypeList + return curFormatTypeList, curDataTypeList def getFormatTypeListX(self): - curDataTypeList=['DT_NULL_VALUE' for i in range(len(self._attributeNameList))] + curDataTypeList = ["DT_NULL_VALUE" for i in range(len(self._attributeNameList))] for row in self._rowList: for indx in range(len(self._attributeNameList)): - val=row[indx] - #print "index ",indx," val ",val - dType=self.__dataTypePdbx(val) - dIndx=self.__dataTypeList.index(dType) - #print "d type", dType, " d type index ",dIndx - - cType=curDataTypeList[indx] - cIndx=self.__dataTypeList.index(cType) - cIndx= max(cIndx,dIndx) - curDataTypeList[indx]=self.__dataTypeList[cIndx] + val = row[indx] + # print "index ",indx," val ",val + dType = self.__dataTypePdbx(val) + dIndx = self.__dataTypeList.index(dType) + # print "d type", dType, " d type index ",dIndx + + cType = curDataTypeList[indx] + cIndx = self.__dataTypeList.index(cType) + cIndx = max(cIndx, dIndx) + curDataTypeList[indx] = self.__dataTypeList[cIndx] # Map the format types to the data types - curFormatTypeList=[] + curFormatTypeList = [] for dt in curDataTypeList: - ii=self.__dataTypeList.index(dt) + ii = self.__dataTypeList.index(dt) curFormatTypeList.append(self.__formatTypeList[ii]) - return curFormatTypeList,curDataTypeList - - + return curFormatTypeList, curDataTypeList diff --git a/gufe/vendor/pdb_file/PdbxReader.py b/gufe/vendor/pdb_file/PdbxReader.py index 8b49dfc9..96f161e0 100644 --- a/gufe/vendor/pdb_file/PdbxReader.py +++ b/gufe/vendor/pdb_file/PdbxReader.py @@ -16,25 +16,27 @@ The tokenizer used in this module is modeled after the clever parser design used in the PyMMLIB package. - + PyMMLib Development Group Authors: Ethan Merritt: merritt@u.washington.ed & Jay Painter: jay.painter@gmail.com See: http://pymmlib.sourceforge.net/ """ -from __future__ import absolute_import import re + from .PdbxContainers import * + class PdbxError(Exception): - """ Class for catch general errors - """ + """Class for catch general errors""" + pass + class SyntaxError(Exception): - """ Class for catching syntax errors - """ + """Class for catching syntax errors""" + def __init__(self, lineNumber, text): Exception.__init__(self) self.lineNumber = lineNumber @@ -44,27 +46,26 @@ def __str__(self): return "%%ERROR - [at line: %d] %s" % (self.lineNumber, self.text) - -class PdbxReader(object): - """ PDBx reader for data files and dictionaries. - - """ - def __init__(self,ifh): - """ ifh - input file handle returned by open() - """ - # - self.__curLineNumber = 0 - self.__ifh=ifh - self.__stateDict={"data": "ST_DATA_CONTAINER", - "loop": "ST_TABLE", - "global": "ST_GLOBAL_CONTAINER", - "save": "ST_DEFINITION", - "stop": "ST_STOP"} - +class PdbxReader: + """PDBx reader for data files and dictionaries.""" + + def __init__(self, ifh): + """ifh - input file handle returned by open()""" + # + self.__curLineNumber = 0 + self.__ifh = ifh + self.__stateDict = { + "data": "ST_DATA_CONTAINER", + "loop": "ST_TABLE", + "global": "ST_GLOBAL_CONTAINER", + "save": "ST_DEFINITION", + "stop": "ST_STOP", + } + def read(self, containerList): """ Appends to the input list of definition and data containers. - + """ self.__curLineNumber = 0 try: @@ -72,7 +73,7 @@ def read(self, containerList): except StopIteration: pass except RuntimeError as err: - if 'StopIteration' not in str(err): + if "StopIteration" not in str(err): raise else: raise PdbxError() @@ -80,52 +81,51 @@ def read(self, containerList): def __syntaxError(self, errText): raise SyntaxError(self.__curLineNumber, errText) - def __getContainerName(self,inWord): - """ Returns the name of the data_ or save_ container - """ + def __getContainerName(self, inWord): + """Returns the name of the data_ or save_ container""" return str(inWord[5:]).strip() - + def __getState(self, inWord): - """Identifies reserved syntax elements and assigns an associated state. + """Identifies reserved syntax elements and assigns an associated state. - Returns: (reserved word, state) - where - - reserved word - is one of CIF syntax elements: - data_, loop_, global_, save_, stop_ - state - the parser state required to process this next section. + Returns: (reserved word, state) + where - + reserved word - is one of CIF syntax elements: + data_, loop_, global_, save_, stop_ + state - the parser state required to process this next section. """ i = inWord.find("_") if i == -1: - return None,"ST_UNKNOWN" + return None, "ST_UNKNOWN" try: - rWord=inWord[:i].lower() + rWord = inWord[:i].lower() return rWord, self.__stateDict[rWord] except: - return None,"ST_UNKNOWN" - + return None, "ST_UNKNOWN" + def __parser(self, tokenizer, containerList): - """ Parser for PDBx data files and dictionaries. + """Parser for PDBx data files and dictionaries. - Input - tokenizer() reentrant method recognizing data item names (_category.attribute) - quoted strings (single, double and multi-line semi-colon delimited), and unquoted - strings. + Input - tokenizer() reentrant method recognizing data item names (_category.attribute) + quoted strings (single, double and multi-line semi-colon delimited), and unquoted + strings. - containerList - list-type container for data and definition objects parsed from - from the input file. + containerList - list-type container for data and definition objects parsed from + from the input file. - Return: - containerList - is appended with data and definition objects - + Return: + containerList - is appended with data and definition objects - """ # Working container - data or definition curContainer = None # - # Working category container + # Working category container categoryIndex = {} curCategory = None # curRow = None - state = None + state = None # Find the first reserved word and begin capturing data. # @@ -133,16 +133,16 @@ def __parser(self, tokenizer, containerList): curCatName, curAttName, curQuotedString, curWord = next(tokenizer) if curWord is None: continue - reservedWord, state = self.__getState(curWord) + reservedWord, state = self.__getState(curWord) if reservedWord is not None: break - + while True: # # Set the current state - # # At this point in the processing cycle we are expecting a token containing - # either a '_category.attribute' or a reserved word. + # either a '_category.attribute' or a reserved word. # if curCatName is not None: state = "ST_KEY_VALUE_PAIR" @@ -150,16 +150,16 @@ def __parser(self, tokenizer, containerList): reservedWord, state = self.__getState(curWord) else: self.__syntaxError("Miscellaneous syntax error") - return + return # - # Process _category.attribute value assignments + # Process _category.attribute value assignments # if state == "ST_KEY_VALUE_PAIR": try: curCategory = categoryIndex[curCatName] except KeyError: - # A new category is encountered - create a container and add a row + # A new category is encountered - create a container and add a row curCategory = categoryIndex[curCatName] = DataCategory(curCatName) try: @@ -168,12 +168,12 @@ def __parser(self, tokenizer, containerList): self.__syntaxError("Category cannot be added to data_ block") return - curRow = [] + curRow = [] curCategory.append(curRow) else: # Recover the existing row from the category try: - curRow = curCategory[0] + curRow = curCategory[0] except IndexError: self.__syntaxError("Internal index error accessing category data") return @@ -185,18 +185,17 @@ def __parser(self, tokenizer, containerList): else: curCategory.appendAttribute(curAttName) - # Get the data for this attribute from the next token tCat, tAtt, curQuotedString, curWord = next(tokenizer) if tCat is not None or (curQuotedString is None and curWord is None): - self.__syntaxError("Missing data for item _%s.%s" % (curCatName,curAttName)) + self.__syntaxError(f"Missing data for item _{curCatName}.{curAttName}") if curWord is not None: - # - # Validation check token for misplaced reserved words - # - reservedWord, state = self.__getState(curWord) + # Validation check token for misplaced reserved words - + # + reservedWord, state = self.__getState(curWord) if reservedWord is not None: self.__syntaxError("Unexpected reserved word: %s" % (reservedWord)) @@ -218,7 +217,7 @@ def __parser(self, tokenizer, containerList): # The category name in the next curCatName,curAttName pair # defines the name of the category container. - curCatName,curAttName,curQuotedString,curWord = next(tokenizer) + curCatName, curAttName, curQuotedString, curWord = next(tokenizer) if curCatName is None or curAttName is None: self.__syntaxError("Unexpected token in loop_ declaration") @@ -239,10 +238,10 @@ def __parser(self, tokenizer, containerList): curCategory.appendAttribute(curAttName) - # Read the rest of the loop_ declaration + # Read the rest of the loop_ declaration while True: curCatName, curAttName, curQuotedString, curWord = next(tokenizer) - + if curCatName is None: break @@ -252,19 +251,18 @@ def __parser(self, tokenizer, containerList): curCategory.appendAttribute(curAttName) - - # If the next token is a 'word', check it for any reserved words - + # If the next token is a 'word', check it for any reserved words - if curWord is not None: - reservedWord, state = self.__getState(curWord) + reservedWord, state = self.__getState(curWord) if reservedWord is not None: if reservedWord == "stop": return else: self.__syntaxError("Unexpected reserved word after loop declaration: %s" % (reservedWord)) - - # Read the table of data for this loop_ - + + # Read the table of data for this loop_ - while True: - curRow = [] + curRow = [] curCategory.append(curRow) for tAtt in curCategory.getAttributeList(): @@ -273,9 +271,9 @@ def __parser(self, tokenizer, containerList): elif curQuotedString is not None: curRow.append(curQuotedString) - curCatName,curAttName,curQuotedString,curWord = next(tokenizer) + curCatName, curAttName, curQuotedString, curWord = next(tokenizer) - # loop_ data processing ends if - + # loop_ data processing ends if - # A new _category.attribute is encountered if curCatName is not None: @@ -286,31 +284,30 @@ def __parser(self, tokenizer, containerList): reservedWord, state = self.__getState(curWord) if reservedWord is not None: break - - continue + continue elif state == "ST_DEFINITION": # Ignore trailing unnamed saveframe delimiters e.g. 'save_' - sName=self.__getContainerName(curWord) - if (len(sName) > 0): + sName = self.__getContainerName(curWord) + if len(sName) > 0: curContainer = DefinitionContainer(sName) containerList.append(curContainer) categoryIndex = {} curCategory = None - curCatName,curAttName,curQuotedString,curWord = next(tokenizer) + curCatName, curAttName, curQuotedString, curWord = next(tokenizer) elif state == "ST_DATA_CONTAINER": # - dName=self.__getContainerName(curWord) + dName = self.__getContainerName(curWord) if len(dName) == 0: - dName="unidentified" + dName = "unidentified" curContainer = DataContainer(dName) containerList.append(curContainer) categoryIndex = {} curCategory = None - curCatName,curAttName,curQuotedString,curWord = next(tokenizer) + curCatName, curAttName, curQuotedString, curWord = next(tokenizer) elif state == "ST_STOP": return @@ -320,22 +317,21 @@ def __parser(self, tokenizer, containerList): containerList.append(curContainer) categoryIndex = {} curCategory = None - curCatName,curAttName,curQuotedString,curWord = next(tokenizer) + curCatName, curAttName, curQuotedString, curWord = next(tokenizer) elif state == "ST_UNKNOWN": self.__syntaxError("Unrecogized syntax element: " + str(curWord)) return - def __tokenizer(self, ifh): - """ Tokenizer method for the mmCIF syntax file - + """Tokenizer method for the mmCIF syntax file - - Each return/yield from this method returns information about - the next token in the form of a tuple with the following structure. + Each return/yield from this method returns information about + the next token in the form of a tuple with the following structure. - (category name, attribute name, quoted strings, words w/o quotes or white space) + (category name, attribute name, quoted strings, words w/o quotes or white space) - Differentiated the reqular expression to the better handle embedded quotes. + Differentiated the reqular expression to the better handle embedded quotes. """ # @@ -343,17 +339,17 @@ def __tokenizer(self, ifh): # outside of this regex. mmcifRe = re.compile( r"(?:" - - "(?:_(.+?)[.](\S+))" "|" # _category.attribute - - "(?:['](.*?)(?:[']\s|[']$))" "|" # single quoted strings - "(?:[\"](.*?)(?:[\"]\s|[\"]$))" "|" # double quoted strings - - "(?:\s*#.*$)" "|" # comments (dumped) - - "(\S+)" # unquoted words - - ")") + r"(?:_(.+?)[.](\S+))" + "|" # _category.attribute + r"(?:['](.*?)(?:[']\s|[']$))" + "|" # single quoted strings + r'(?:["](.*?)(?:["]\s|["]$))' + "|" # double quoted strings + r"(?:\s*#.*$)" + "|" # comments (dumped) + r"(\S+)" # unquoted words + ")" + ) fileIter = iter(ifh) @@ -365,7 +361,7 @@ def __tokenizer(self, ifh): # Dump comments if line.startswith("#"): continue - + # Gobble up the entire semi-colon/multi-line delimited string and # and stuff this into the string slot in the return tuple # @@ -385,7 +381,7 @@ def __tokenizer(self, ifh): # # Need to process the remainder of the current line - line = line[1:] - #continue + # continue # Apply regex to the current line consolidate the single/double # quoted within the quoted string category @@ -398,16 +394,16 @@ def __tokenizer(self, ifh): qs = tgroups[3] else: qs = None - groups = (tgroups[0],tgroups[1],qs,tgroups[4]) + groups = (tgroups[0], tgroups[1], qs, tgroups[4]) yield groups def __tokenizerOrg(self, ifh): - """ Tokenizer method for the mmCIF syntax file - + """Tokenizer method for the mmCIF syntax file - - Each return/yield from this method returns information about - the next token in the form of a tuple with the following structure. + Each return/yield from this method returns information about + the next token in the form of a tuple with the following structure. - (category name, attribute name, quoted strings, words w/o quotes or white space) + (category name, attribute name, quoted strings, words w/o quotes or white space) """ # @@ -415,16 +411,15 @@ def __tokenizerOrg(self, ifh): # outside of this regex. mmcifRe = re.compile( r"(?:" - - "(?:_(.+?)[.](\S+))" "|" # _category.attribute - - "(?:['\"](.*?)(?:['\"]\s|['\"]$))" "|" # quoted strings - - "(?:\s*#.*$)" "|" # comments (dumped) - - "(\S+)" # unquoted words - - ")") + r"(?:_(.+?)[.](\S+))" + "|" # _category.attribute + "(?:['\"](.*?)(?:['\"]\\s|['\"]$))" + "|" # quoted strings + r"(?:\s*#.*$)" + "|" # comments (dumped) + r"(\S+)" # unquoted words + ")" + ) fileIter = iter(ifh) @@ -436,7 +431,7 @@ def __tokenizerOrg(self, ifh): # Dump comments if line.startswith("#"): continue - + # Gobble up the entire semi-colon/multi-line delimited string and # and stuff this into the string slot in the return tuple # @@ -456,9 +451,9 @@ def __tokenizerOrg(self, ifh): # # Need to process the remainder of the current line - line = line[1:] - #continue + # continue - ## Apply regex to the current line + ## Apply regex to the current line for it in mmcifRe.finditer(line): groups = it.groups() if groups != (None, None, None, None): diff --git a/gufe/vendor/pdb_file/data/residues.xml b/gufe/vendor/pdb_file/data/residues.xml index 7b6cfdd1..2a353a33 100644 --- a/gufe/vendor/pdb_file/data/residues.xml +++ b/gufe/vendor/pdb_file/data/residues.xml @@ -899,4 +899,4 @@ - \ No newline at end of file + diff --git a/gufe/vendor/pdb_file/element.py b/gufe/vendor/pdb_file/element.py index 6d00067c..3b675233 100644 --- a/gufe/vendor/pdb_file/element.py +++ b/gufe/vendor/pdb_file/element.py @@ -28,21 +28,18 @@ OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from __future__ import absolute_import + __author__ = "Christopher M. Bruns" __version__ = "1.0" +import copyreg import sys from collections import OrderedDict -from openmm.unit import daltons, is_quantity -if sys.version_info >= (3, 0): - import copyreg -else: - import copy_reg as copyreg +from openmm.unit import daltons, is_quantity -class Element(object): +class Element: """An Element represents a chemical element. The openmm.app.element module contains objects for all the standard chemical elements, @@ -84,7 +81,7 @@ def __init__(self, number, name, symbol, mass): Element._elements_by_mass = None if s in Element._elements_by_symbol: - raise ValueError('Duplicate element symbol %s' % s) + raise ValueError("Duplicate element symbol %s" % s) Element._elements_by_symbol[s] = self if number in Element._elements_by_atomic_number: other_element = Element._elements_by_atomic_number[number] @@ -128,13 +125,12 @@ def getByMass(mass): if is_quantity(mass): mass = mass.value_in_unit(daltons) if mass < 0: - raise ValueError('Invalid Higgs field') + raise ValueError("Invalid Higgs field") # If this is our first time calling getByMass (or we added an element # since the last call), re-generate the ordered by-mass dict cache if Element._elements_by_mass is None: Element._elements_by_mass = OrderedDict() - for elem in sorted(Element._elements_by_symbol.values(), - key=lambda x: x.mass): + for elem in sorted(Element._elements_by_symbol.values(), key=lambda x: x.mass): Element._elements_by_mass[elem.mass.value_in_unit(daltons)] = elem diff = mass @@ -170,151 +166,151 @@ def mass(self): return self._mass def __str__(self): - return '' % self.name + return "" % self.name def __repr__(self): - return '' % self.name + return "" % self.name + # This is for backward compatibility. def get_by_symbol(symbol): - """ Get the element with a particular chemical symbol. """ + """Get the element with a particular chemical symbol.""" s = symbol.strip().upper() return Element._elements_by_symbol[s] + def _pickle_element(element): return (get_by_symbol, (element.symbol,)) + copyreg.pickle(Element, _pickle_element) # NOTE: getElementByMass assumes all masses are Quantity instances with unit # "daltons". All elements need to obey this assumption, or that method will # fail. No checking is done in getElementByMass for performance reasons -hydrogen = Element( 1, "hydrogen", "H", 1.007947*daltons) -deuterium = Element( 1, "deuterium", "D", 2.01355321270*daltons) -helium = Element( 2, "helium", "He", 4.003*daltons) -lithium = Element( 3, "lithium", "Li", 6.9412*daltons) -beryllium = Element( 4, "beryllium", "Be", 9.0121823*daltons) -boron = Element( 5, "boron", "B", 10.8117*daltons) -carbon = Element( 6, "carbon", "C", 12.01078*daltons) -nitrogen = Element( 7, "nitrogen", "N", 14.00672*daltons) -oxygen = Element( 8, "oxygen", "O", 15.99943*daltons) -fluorine = Element( 9, "fluorine", "F", 18.99840325*daltons) -neon = Element( 10, "neon", "Ne", 20.17976*daltons) -sodium = Element( 11, "sodium", "Na", 22.989769282*daltons) -magnesium = Element( 12, "magnesium", "Mg", 24.30506*daltons) -aluminum = Element( 13, "aluminum", "Al", 26.98153868*daltons) -silicon = Element( 14, "silicon", "Si", 28.08553*daltons) -phosphorus = Element( 15, "phosphorus", "P", 30.9737622*daltons) -sulfur = Element( 16, "sulfur", "S", 32.0655*daltons) -chlorine = Element( 17, "chlorine", "Cl", 35.4532*daltons) -argon = Element( 18, "argon", "Ar", 39.9481*daltons) -potassium = Element( 19, "potassium", "K", 39.09831*daltons) -calcium = Element( 20, "calcium", "Ca", 40.0784*daltons) -scandium = Element( 21, "scandium", "Sc", 44.9559126*daltons) -titanium = Element( 22, "titanium", "Ti", 47.8671*daltons) -vanadium = Element( 23, "vanadium", "V", 50.94151*daltons) -chromium = Element( 24, "chromium", "Cr", 51.99616*daltons) -manganese = Element( 25, "manganese", "Mn", 54.9380455*daltons) -iron = Element( 26, "iron", "Fe", 55.8452*daltons) -cobalt = Element( 27, "cobalt", "Co", 58.9331955*daltons) -nickel = Element( 28, "nickel", "Ni", 58.69342*daltons) -copper = Element( 29, "copper", "Cu", 63.5463*daltons) -zinc = Element( 30, "zinc", "Zn", 65.4094*daltons) -gallium = Element( 31, "gallium", "Ga", 69.7231*daltons) -germanium = Element( 32, "germanium", "Ge", 72.641*daltons) -arsenic = Element( 33, "arsenic", "As", 74.921602*daltons) -selenium = Element( 34, "selenium", "Se", 78.963*daltons) -bromine = Element( 35, "bromine", "Br", 79.9041*daltons) -krypton = Element( 36, "krypton", "Kr", 83.7982*daltons) -rubidium = Element( 37, "rubidium", "Rb", 85.46783*daltons) -strontium = Element( 38, "strontium", "Sr", 87.621*daltons) -yttrium = Element( 39, "yttrium", "Y", 88.905852*daltons) -zirconium = Element( 40, "zirconium", "Zr", 91.2242*daltons) -niobium = Element( 41, "niobium", "Nb", 92.906382*daltons) -molybdenum = Element( 42, "molybdenum", "Mo", 95.942*daltons) -technetium = Element( 43, "technetium", "Tc", 98*daltons) -ruthenium = Element( 44, "ruthenium", "Ru", 101.072*daltons) -rhodium = Element( 45, "rhodium", "Rh", 102.905502*daltons) -palladium = Element( 46, "palladium", "Pd", 106.421*daltons) -silver = Element( 47, "silver", "Ag", 107.86822*daltons) -cadmium = Element( 48, "cadmium", "Cd", 112.4118*daltons) -indium = Element( 49, "indium", "In", 114.8183*daltons) -tin = Element( 50, "tin", "Sn", 118.7107*daltons) -antimony = Element( 51, "antimony", "Sb", 121.7601*daltons) -tellurium = Element( 52, "tellurium", "Te", 127.603*daltons) -iodine = Element( 53, "iodine", "I", 126.904473*daltons) -xenon = Element( 54, "xenon", "Xe", 131.2936*daltons) -cesium = Element( 55, "cesium", "Cs", 132.90545192*daltons) -barium = Element( 56, "barium", "Ba", 137.3277*daltons) -lanthanum = Element( 57, "lanthanum", "La", 138.905477*daltons) -cerium = Element( 58, "cerium", "Ce", 140.1161*daltons) -praseodymium = Element( 59, "praseodymium", "Pr", 140.907652*daltons) -neodymium = Element( 60, "neodymium", "Nd", 144.2423*daltons) -promethium = Element( 61, "promethium", "Pm", 145*daltons) -samarium = Element( 62, "samarium", "Sm", 150.362*daltons) -europium = Element( 63, "europium", "Eu", 151.9641*daltons) -gadolinium = Element( 64, "gadolinium", "Gd", 157.253*daltons) -terbium = Element( 65, "terbium", "Tb", 158.925352*daltons) -dysprosium = Element( 66, "dysprosium", "Dy", 162.5001*daltons) -holmium = Element( 67, "holmium", "Ho", 164.930322*daltons) -erbium = Element( 68, "erbium", "Er", 167.2593*daltons) -thulium = Element( 69, "thulium", "Tm", 168.934212*daltons) -ytterbium = Element( 70, "ytterbium", "Yb", 173.043*daltons) -lutetium = Element( 71, "lutetium", "Lu", 174.9671*daltons) -hafnium = Element( 72, "hafnium", "Hf", 178.492*daltons) -tantalum = Element( 73, "tantalum", "Ta", 180.947882*daltons) -tungsten = Element( 74, "tungsten", "W", 183.841*daltons) -rhenium = Element( 75, "rhenium", "Re", 186.2071*daltons) -osmium = Element( 76, "osmium", "Os", 190.233*daltons) -iridium = Element( 77, "iridium", "Ir", 192.2173*daltons) -platinum = Element( 78, "platinum", "Pt", 195.0849*daltons) -gold = Element( 79, "gold", "Au", 196.9665694*daltons) -mercury = Element( 80, "mercury", "Hg", 200.592*daltons) -thallium = Element( 81, "thallium", "Tl", 204.38332*daltons) -lead = Element( 82, "lead", "Pb", 207.21*daltons) -bismuth = Element( 83, "bismuth", "Bi", 208.980401*daltons) -polonium = Element( 84, "polonium", "Po", 209*daltons) -astatine = Element( 85, "astatine", "At", 210*daltons) -radon = Element( 86, "radon", "Rn", 222.018*daltons) -francium = Element( 87, "francium", "Fr", 223*daltons) -radium = Element( 88, "radium", "Ra", 226*daltons) -actinium = Element( 89, "actinium", "Ac", 227*daltons) -thorium = Element( 90, "thorium", "Th", 232.038062*daltons) -protactinium = Element( 91, "protactinium", "Pa", 231.035882*daltons) -uranium = Element( 92, "uranium", "U", 238.028913*daltons) -neptunium = Element( 93, "neptunium", "Np", 237*daltons) -plutonium = Element( 94, "plutonium", "Pu", 244*daltons) -americium = Element( 95, "americium", "Am", 243*daltons) -curium = Element( 96, "curium", "Cm", 247*daltons) -berkelium = Element( 97, "berkelium", "Bk", 247*daltons) -californium = Element( 98, "californium", "Cf", 251*daltons) -einsteinium = Element( 99, "einsteinium", "Es", 252*daltons) -fermium = Element(100, "fermium", "Fm", 257*daltons) -mendelevium = Element(101, "mendelevium", "Md", 258*daltons) -nobelium = Element(102, "nobelium", "No", 259*daltons) -lawrencium = Element(103, "lawrencium", "Lr", 262*daltons) -rutherfordium = Element(104, "rutherfordium", "Rf", 261*daltons) -dubnium = Element(105, "dubnium", "Db", 262*daltons) -seaborgium = Element(106, "seaborgium", "Sg", 266*daltons) -bohrium = Element(107, "bohrium", "Bh", 264*daltons) -hassium = Element(108, "hassium", "Hs", 269*daltons) -meitnerium = Element(109, "meitnerium", "Mt", 268*daltons) -darmstadtium = Element(110, "darmstadtium", "Ds", 281*daltons) -roentgenium = Element(111, "roentgenium", "Rg", 272*daltons) -ununbium = Element(112, "ununbium", "Uub", 285*daltons) -ununtrium = Element(113, "ununtrium", "Uut", 284*daltons) -ununquadium = Element(114, "ununquadium", "Uuq", 289*daltons) -ununpentium = Element(115, "ununpentium", "Uup", 288*daltons) -ununhexium = Element(116, "ununhexium", "Uuh", 292*daltons) +hydrogen = Element(1, "hydrogen", "H", 1.007947 * daltons) +deuterium = Element(1, "deuterium", "D", 2.01355321270 * daltons) +helium = Element(2, "helium", "He", 4.003 * daltons) +lithium = Element(3, "lithium", "Li", 6.9412 * daltons) +beryllium = Element(4, "beryllium", "Be", 9.0121823 * daltons) +boron = Element(5, "boron", "B", 10.8117 * daltons) +carbon = Element(6, "carbon", "C", 12.01078 * daltons) +nitrogen = Element(7, "nitrogen", "N", 14.00672 * daltons) +oxygen = Element(8, "oxygen", "O", 15.99943 * daltons) +fluorine = Element(9, "fluorine", "F", 18.99840325 * daltons) +neon = Element(10, "neon", "Ne", 20.17976 * daltons) +sodium = Element(11, "sodium", "Na", 22.989769282 * daltons) +magnesium = Element(12, "magnesium", "Mg", 24.30506 * daltons) +aluminum = Element(13, "aluminum", "Al", 26.98153868 * daltons) +silicon = Element(14, "silicon", "Si", 28.08553 * daltons) +phosphorus = Element(15, "phosphorus", "P", 30.9737622 * daltons) +sulfur = Element(16, "sulfur", "S", 32.0655 * daltons) +chlorine = Element(17, "chlorine", "Cl", 35.4532 * daltons) +argon = Element(18, "argon", "Ar", 39.9481 * daltons) +potassium = Element(19, "potassium", "K", 39.09831 * daltons) +calcium = Element(20, "calcium", "Ca", 40.0784 * daltons) +scandium = Element(21, "scandium", "Sc", 44.9559126 * daltons) +titanium = Element(22, "titanium", "Ti", 47.8671 * daltons) +vanadium = Element(23, "vanadium", "V", 50.94151 * daltons) +chromium = Element(24, "chromium", "Cr", 51.99616 * daltons) +manganese = Element(25, "manganese", "Mn", 54.9380455 * daltons) +iron = Element(26, "iron", "Fe", 55.8452 * daltons) +cobalt = Element(27, "cobalt", "Co", 58.9331955 * daltons) +nickel = Element(28, "nickel", "Ni", 58.69342 * daltons) +copper = Element(29, "copper", "Cu", 63.5463 * daltons) +zinc = Element(30, "zinc", "Zn", 65.4094 * daltons) +gallium = Element(31, "gallium", "Ga", 69.7231 * daltons) +germanium = Element(32, "germanium", "Ge", 72.641 * daltons) +arsenic = Element(33, "arsenic", "As", 74.921602 * daltons) +selenium = Element(34, "selenium", "Se", 78.963 * daltons) +bromine = Element(35, "bromine", "Br", 79.9041 * daltons) +krypton = Element(36, "krypton", "Kr", 83.7982 * daltons) +rubidium = Element(37, "rubidium", "Rb", 85.46783 * daltons) +strontium = Element(38, "strontium", "Sr", 87.621 * daltons) +yttrium = Element(39, "yttrium", "Y", 88.905852 * daltons) +zirconium = Element(40, "zirconium", "Zr", 91.2242 * daltons) +niobium = Element(41, "niobium", "Nb", 92.906382 * daltons) +molybdenum = Element(42, "molybdenum", "Mo", 95.942 * daltons) +technetium = Element(43, "technetium", "Tc", 98 * daltons) +ruthenium = Element(44, "ruthenium", "Ru", 101.072 * daltons) +rhodium = Element(45, "rhodium", "Rh", 102.905502 * daltons) +palladium = Element(46, "palladium", "Pd", 106.421 * daltons) +silver = Element(47, "silver", "Ag", 107.86822 * daltons) +cadmium = Element(48, "cadmium", "Cd", 112.4118 * daltons) +indium = Element(49, "indium", "In", 114.8183 * daltons) +tin = Element(50, "tin", "Sn", 118.7107 * daltons) +antimony = Element(51, "antimony", "Sb", 121.7601 * daltons) +tellurium = Element(52, "tellurium", "Te", 127.603 * daltons) +iodine = Element(53, "iodine", "I", 126.904473 * daltons) +xenon = Element(54, "xenon", "Xe", 131.2936 * daltons) +cesium = Element(55, "cesium", "Cs", 132.90545192 * daltons) +barium = Element(56, "barium", "Ba", 137.3277 * daltons) +lanthanum = Element(57, "lanthanum", "La", 138.905477 * daltons) +cerium = Element(58, "cerium", "Ce", 140.1161 * daltons) +praseodymium = Element(59, "praseodymium", "Pr", 140.907652 * daltons) +neodymium = Element(60, "neodymium", "Nd", 144.2423 * daltons) +promethium = Element(61, "promethium", "Pm", 145 * daltons) +samarium = Element(62, "samarium", "Sm", 150.362 * daltons) +europium = Element(63, "europium", "Eu", 151.9641 * daltons) +gadolinium = Element(64, "gadolinium", "Gd", 157.253 * daltons) +terbium = Element(65, "terbium", "Tb", 158.925352 * daltons) +dysprosium = Element(66, "dysprosium", "Dy", 162.5001 * daltons) +holmium = Element(67, "holmium", "Ho", 164.930322 * daltons) +erbium = Element(68, "erbium", "Er", 167.2593 * daltons) +thulium = Element(69, "thulium", "Tm", 168.934212 * daltons) +ytterbium = Element(70, "ytterbium", "Yb", 173.043 * daltons) +lutetium = Element(71, "lutetium", "Lu", 174.9671 * daltons) +hafnium = Element(72, "hafnium", "Hf", 178.492 * daltons) +tantalum = Element(73, "tantalum", "Ta", 180.947882 * daltons) +tungsten = Element(74, "tungsten", "W", 183.841 * daltons) +rhenium = Element(75, "rhenium", "Re", 186.2071 * daltons) +osmium = Element(76, "osmium", "Os", 190.233 * daltons) +iridium = Element(77, "iridium", "Ir", 192.2173 * daltons) +platinum = Element(78, "platinum", "Pt", 195.0849 * daltons) +gold = Element(79, "gold", "Au", 196.9665694 * daltons) +mercury = Element(80, "mercury", "Hg", 200.592 * daltons) +thallium = Element(81, "thallium", "Tl", 204.38332 * daltons) +lead = Element(82, "lead", "Pb", 207.21 * daltons) +bismuth = Element(83, "bismuth", "Bi", 208.980401 * daltons) +polonium = Element(84, "polonium", "Po", 209 * daltons) +astatine = Element(85, "astatine", "At", 210 * daltons) +radon = Element(86, "radon", "Rn", 222.018 * daltons) +francium = Element(87, "francium", "Fr", 223 * daltons) +radium = Element(88, "radium", "Ra", 226 * daltons) +actinium = Element(89, "actinium", "Ac", 227 * daltons) +thorium = Element(90, "thorium", "Th", 232.038062 * daltons) +protactinium = Element(91, "protactinium", "Pa", 231.035882 * daltons) +uranium = Element(92, "uranium", "U", 238.028913 * daltons) +neptunium = Element(93, "neptunium", "Np", 237 * daltons) +plutonium = Element(94, "plutonium", "Pu", 244 * daltons) +americium = Element(95, "americium", "Am", 243 * daltons) +curium = Element(96, "curium", "Cm", 247 * daltons) +berkelium = Element(97, "berkelium", "Bk", 247 * daltons) +californium = Element(98, "californium", "Cf", 251 * daltons) +einsteinium = Element(99, "einsteinium", "Es", 252 * daltons) +fermium = Element(100, "fermium", "Fm", 257 * daltons) +mendelevium = Element(101, "mendelevium", "Md", 258 * daltons) +nobelium = Element(102, "nobelium", "No", 259 * daltons) +lawrencium = Element(103, "lawrencium", "Lr", 262 * daltons) +rutherfordium = Element(104, "rutherfordium", "Rf", 261 * daltons) +dubnium = Element(105, "dubnium", "Db", 262 * daltons) +seaborgium = Element(106, "seaborgium", "Sg", 266 * daltons) +bohrium = Element(107, "bohrium", "Bh", 264 * daltons) +hassium = Element(108, "hassium", "Hs", 269 * daltons) +meitnerium = Element(109, "meitnerium", "Mt", 268 * daltons) +darmstadtium = Element(110, "darmstadtium", "Ds", 281 * daltons) +roentgenium = Element(111, "roentgenium", "Rg", 272 * daltons) +ununbium = Element(112, "ununbium", "Uub", 285 * daltons) +ununtrium = Element(113, "ununtrium", "Uut", 284 * daltons) +ununquadium = Element(114, "ununquadium", "Uuq", 289 * daltons) +ununpentium = Element(115, "ununpentium", "Uup", 288 * daltons) +ununhexium = Element(116, "ununhexium", "Uuh", 292 * daltons) # Aliases to recognize common alternative spellings. Both the '==' and 'is' # relational operators will work with any chosen name sulphur = sulfur aluminium = aluminum -if sys.version_info >= (3, 0): - def _iteritems(dict): - return dict.items() -else: - def _iteritems(dict): - return dict.iteritems() + +def _iteritems(dict): + return dict.items() diff --git a/gufe/vendor/pdb_file/pdbfile.py b/gufe/vendor/pdb_file/pdbfile.py index df769055..9eb26111 100644 --- a/gufe/vendor/pdb_file/pdbfile.py +++ b/gufe/vendor/pdb_file/pdbfile.py @@ -28,40 +28,71 @@ OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from __future__ import print_function, division, absolute_import + __author__ = "Peter Eastman" # was true but now riesben vendored this! HarrHarr __version__ = "1.0" +import math import os import sys -import math -import numpy as np import xml.etree.ElementTree as etree from copy import copy from datetime import date +import numpy as np +from openmm.unit import Quantity, angstroms, is_quantity, nanometers, norm + +from . import element as elem +from .pdbstructure import PdbStructure from .topology import Topology from .unitcell import computeLengthsAndAngles -from .pdbstructure import PdbStructure -from . import element as elem -from openmm.unit import nanometers, angstroms, is_quantity, norm, Quantity - -class PDBFile(object): +class PDBFile: """PDBFile parses a Protein Data Bank (PDB) file and constructs a Topology and a set of atom positions from it. This class also provides methods for creating PDB files. To write a file containing a single model, call writeFile(). You also can create files that contain multiple models. To do this, first call writeHeader(), - then writeModel() once for each model in the file, and finally writeFooter() to complete the file.""" + then writeModel() once for each model in the file, and finally writeFooter() to complete the file. + """ _residueNameReplacements = {} _atomNameReplacements = {} - _standardResidues = ['ALA', 'ASN', 'CYS', 'GLU', 'HIS', 'LEU', 'MET', 'PRO', 'THR', 'TYR', - 'ARG', 'ASP', 'GLN', 'GLY', 'ILE', 'LYS', 'PHE', 'SER', 'TRP', 'VAL', - 'A', 'G', 'C', 'U', 'I', 'DA', 'DG', 'DC', 'DT', 'DI', 'HOH'] - - def __init__(self, file, extraParticleIdentifier='EP'): + _standardResidues = [ + "ALA", + "ASN", + "CYS", + "GLU", + "HIS", + "LEU", + "MET", + "PRO", + "THR", + "TYR", + "ARG", + "ASP", + "GLN", + "GLY", + "ILE", + "LYS", + "PHE", + "SER", + "TRP", + "VAL", + "A", + "G", + "C", + "U", + "I", + "DA", + "DG", + "DC", + "DT", + "DI", + "HOH", + ] + + def __init__(self, file, extraParticleIdentifier="EP"): """Load a PDB file. The atom positions and Topology can be retrieved by calling getPositions() and getTopology(). @@ -73,10 +104,46 @@ def __init__(self, file, extraParticleIdentifier='EP'): extraParticleIdentifier : string='EP' if this value appears in the element column for an ATOM record, the Atom's element will be set to None to mark it as an extra particle """ - - metalElements = ['Al','As','Ba','Ca','Cd','Ce','Co','Cs','Cu','Dy','Fe','Gd','Hg','Ho','In','Ir','K','Li','Mg', - 'Mn','Mo','Na','Ni','Pb','Pd','Pt','Rb','Rh','Sm','Sr','Te','Tl','V','W','Yb','Zn'] - + + metalElements = [ + "Al", + "As", + "Ba", + "Ca", + "Cd", + "Ce", + "Co", + "Cs", + "Cu", + "Dy", + "Fe", + "Gd", + "Hg", + "Ho", + "In", + "Ir", + "K", + "Li", + "Mg", + "Mn", + "Mo", + "Na", + "Ni", + "Pb", + "Pd", + "Pt", + "Rb", + "Rh", + "Sm", + "Sr", + "Te", + "Tl", + "V", + "W", + "Yb", + "Zn", + ] + top = Topology() ## The Topology read from the PDB file self.topology = top @@ -91,7 +158,11 @@ def __init__(self, file, extraParticleIdentifier='EP'): if isinstance(file, str): inputfile = open(file) own_handle = True - pdb = PdbStructure(inputfile, load_all_models=True, extraParticleIdentifier=extraParticleIdentifier) + pdb = PdbStructure( + inputfile, + load_all_models=True, + extraParticleIdentifier=extraParticleIdentifier, + ) if own_handle: inputfile.close() PDBFile._loadNameReplacementTables() @@ -120,7 +191,7 @@ def __init__(self, file, extraParticleIdentifier='EP'): atomName = atomReplacements[atomName] atomName = atomName.strip() element = atom.element - if element == 'EP': + if element == "EP": element = None elif element is None: # Try to guess the element. @@ -128,24 +199,24 @@ def __init__(self, file, extraParticleIdentifier='EP'): upper = atomName.upper() while len(upper) > 1 and upper[0].isdigit(): upper = upper[1:] - if upper.startswith('CL'): + if upper.startswith("CL"): element = elem.chlorine - elif upper.startswith('NA'): + elif upper.startswith("NA"): element = elem.sodium - elif upper.startswith('MG'): + elif upper.startswith("MG"): element = elem.magnesium - elif upper.startswith('BE'): + elif upper.startswith("BE"): element = elem.beryllium - elif upper.startswith('LI'): + elif upper.startswith("LI"): element = elem.lithium - elif upper.startswith('K'): + elif upper.startswith("K"): element = elem.potassium - elif upper.startswith('ZN'): + elif upper.startswith("ZN"): element = elem.zinc - elif len(residue) == 1 and upper.startswith('CA'): + elif len(residue) == 1 and upper.startswith("CA"): element = elem.calcium - elif upper.startswith('D') and any(a.name == atomName[1:] for a in residue.iter_atoms()): - pass # A Drude particle + elif upper.startswith("D") and any(a.name == atomName[1:] for a in residue.iter_atoms()): + pass # A Drude particle else: try: element = elem.get_by_symbol(upper[0]) @@ -165,7 +236,7 @@ def __init__(self, file, extraParticleIdentifier='EP'): processedAtomNames.add(atom.get_name()) pos = atom.get_position().value_in_unit(nanometers) coords.append(np.array([pos[0], pos[1], pos[2]])) - self._positions.append(coords*nanometers) + self._positions.append(coords * nanometers) ## The atom positions read from the PDB file. If the file contains multiple frames, these are the positions in the first frame. self.positions = self._positions[0] self.topology.setPeriodicBoxVectors(pdb.get_periodic_box_vectors()) @@ -179,16 +250,25 @@ def __init__(self, file, extraParticleIdentifier='EP'): for connect in pdb.models[-1].connects: i = connect[0] for j in connect[1:]: - if i in atomByNumber and j in atomByNumber: + if i in atomByNumber and j in atomByNumber: if atomByNumber[i].element is not None and atomByNumber[j].element is not None: - if atomByNumber[i].element.symbol not in metalElements and atomByNumber[j].element.symbol not in metalElements: - connectBonds.append((atomByNumber[i], atomByNumber[j])) - elif atomByNumber[i].element.symbol in metalElements and atomByNumber[j].residue.name not in PDBFile._standardResidues: - connectBonds.append((atomByNumber[i], atomByNumber[j])) - elif atomByNumber[j].element.symbol in metalElements and atomByNumber[i].residue.name not in PDBFile._standardResidues: - connectBonds.append((atomByNumber[i], atomByNumber[j])) + if ( + atomByNumber[i].element.symbol not in metalElements + and atomByNumber[j].element.symbol not in metalElements + ): + connectBonds.append((atomByNumber[i], atomByNumber[j])) + elif ( + atomByNumber[i].element.symbol in metalElements + and atomByNumber[j].residue.name not in PDBFile._standardResidues + ): + connectBonds.append((atomByNumber[i], atomByNumber[j])) + elif ( + atomByNumber[j].element.symbol in metalElements + and atomByNumber[i].residue.name not in PDBFile._standardResidues + ): + connectBonds.append((atomByNumber[i], atomByNumber[j])) else: - connectBonds.append((atomByNumber[i], atomByNumber[j])) + connectBonds.append((atomByNumber[i], atomByNumber[j])) if len(connectBonds) > 0: # Only add bonds that don't already exist. existingBonds = set(top.bonds()) @@ -218,9 +298,12 @@ def getPositions(self, asNumpy=False, frame=0): """ if asNumpy: if self._numpyPositions is None: - self._numpyPositions = [None]*len(self._positions) + self._numpyPositions = [None] * len(self._positions) if self._numpyPositions[frame] is None: - self._numpyPositions[frame] = Quantity(np.array(self._positions[frame].value_in_unit(nanometers)), nanometers) + self._numpyPositions[frame] = Quantity( + np.array(self._positions[frame].value_in_unit(nanometers)), + nanometers, + ) return self._numpyPositions[frame] return self._positions[frame] @@ -228,31 +311,31 @@ def getPositions(self, asNumpy=False, frame=0): def _loadNameReplacementTables(): """Load the list of atom and residue name replacements.""" if len(PDBFile._residueNameReplacements) == 0: - tree = etree.parse(os.path.join(os.path.dirname(__file__), 'data', 'pdbNames.xml')) + tree = etree.parse(os.path.join(os.path.dirname(__file__), "data", "pdbNames.xml")) allResidues = {} proteinResidues = {} nucleicAcidResidues = {} - for residue in tree.getroot().findall('Residue'): - name = residue.attrib['name'] - if name == 'All': + for residue in tree.getroot().findall("Residue"): + name = residue.attrib["name"] + if name == "All": PDBFile._parseResidueAtoms(residue, allResidues) - elif name == 'Protein': + elif name == "Protein": PDBFile._parseResidueAtoms(residue, proteinResidues) - elif name == 'Nucleic': + elif name == "Nucleic": PDBFile._parseResidueAtoms(residue, nucleicAcidResidues) for atom in allResidues: proteinResidues[atom] = allResidues[atom] nucleicAcidResidues[atom] = allResidues[atom] - for residue in tree.getroot().findall('Residue'): - name = residue.attrib['name'] + for residue in tree.getroot().findall("Residue"): + name = residue.attrib["name"] for id in residue.attrib: - if id == 'name' or id.startswith('alt'): + if id == "name" or id.startswith("alt"): PDBFile._residueNameReplacements[residue.attrib[id]] = name - if 'type' not in residue.attrib: + if "type" not in residue.attrib: atoms = copy(allResidues) - elif residue.attrib['type'] == 'Protein': + elif residue.attrib["type"] == "Protein": atoms = copy(proteinResidues) - elif residue.attrib['type'] == 'Nucleic': + elif residue.attrib["type"] == "Nucleic": atoms = copy(nucleicAcidResidues) else: atoms = copy(allResidues) @@ -261,13 +344,19 @@ def _loadNameReplacementTables(): @staticmethod def _parseResidueAtoms(residue, map): - for atom in residue.findall('Atom'): - name = atom.attrib['name'] + for atom in residue.findall("Atom"): + name = atom.attrib["name"] for id in atom.attrib: map[atom.attrib[id]] = name @staticmethod - def writeFile(topology, positions, file=sys.stdout, keepIds=False, extraParticleIdentifier='EP'): + def writeFile( + topology, + positions, + file=sys.stdout, + keepIds=False, + extraParticleIdentifier="EP", + ): """Write a PDB file containing a single model. Parameters @@ -287,7 +376,13 @@ def writeFile(topology, positions, file=sys.stdout, keepIds=False, extraParticle String to write in the element column of the ATOM records for atoms whose element is None (extra particles) """ PDBFile.writeHeader(topology, file) - PDBFile.writeModel(topology, positions, file, keepIds=keepIds, extraParticleIdentifier=extraParticleIdentifier) + PDBFile.writeModel( + topology, + positions, + file, + keepIds=keepIds, + extraParticleIdentifier=extraParticleIdentifier, + ) PDBFile.writeFooter(topology, file) @staticmethod @@ -305,12 +400,29 @@ def writeHeader(topology, file=sys.stdout): vectors = topology.getPeriodicBoxVectors() if vectors is not None: a, b, c, alpha, beta, gamma = computeLengthsAndAngles(vectors) - RAD_TO_DEG = 180/math.pi - print("CRYST1%9.3f%9.3f%9.3f%7.2f%7.2f%7.2f P 1 1 " % ( - a*10, b*10, c*10, alpha*RAD_TO_DEG, beta*RAD_TO_DEG, gamma*RAD_TO_DEG), file=file) + RAD_TO_DEG = 180 / math.pi + print( + "CRYST1%9.3f%9.3f%9.3f%7.2f%7.2f%7.2f P 1 1 " + % ( + a * 10, + b * 10, + c * 10, + alpha * RAD_TO_DEG, + beta * RAD_TO_DEG, + gamma * RAD_TO_DEG, + ), + file=file, + ) @staticmethod - def writeModel(topology, positions, file=sys.stdout, modelIndex=None, keepIds=False, extraParticleIdentifier='EP'): + def writeModel( + topology, + positions, + file=sys.stdout, + modelIndex=None, + keepIds=False, + extraParticleIdentifier="EP", + ): """Write out a model to a PDB file. Parameters @@ -335,26 +447,30 @@ def writeModel(topology, positions, file=sys.stdout, modelIndex=None, keepIds=Fa """ if len(list(topology.atoms())) != len(positions): - raise ValueError('The number of positions must match the number of atoms') + raise ValueError("The number of positions must match the number of atoms") if is_quantity(positions): positions = positions.value_in_unit(angstroms) if any(math.isnan(norm(pos)) for pos in positions): - raise ValueError('Particle position is NaN. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan') + raise ValueError( + "Particle position is NaN. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan" + ) if any(math.isinf(norm(pos)) for pos in positions): - raise ValueError('Particle position is infinite. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan') + raise ValueError( + "Particle position is infinite. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan" + ) nonHeterogens = PDBFile._standardResidues[:] - nonHeterogens.remove('HOH') + nonHeterogens.remove("HOH") atomIndex = 1 posIndex = 0 if modelIndex is not None: print("MODEL %4d" % modelIndex, file=file) - for (chainIndex, chain) in enumerate(topology.chains()): + for chainIndex, chain in enumerate(topology.chains()): if keepIds and len(chain.id) == 1: chainName = chain.id else: - chainName = chr(ord('A')+chainIndex%26) + chainName = chr(ord("A") + chainIndex % 26) residues = list(chain.residues()) - for (resIndex, res) in enumerate(residues): + for resIndex, res in enumerate(residues): if len(res.name) > 3: resName = res.name[:3] else: @@ -362,7 +478,7 @@ def writeModel(topology, positions, file=sys.stdout, modelIndex=None, keepIds=Fa if keepIds and len(res.id) < 5: resId = res.id else: - resId = _formatIndex(resIndex+1, 4) + resId = _formatIndex(resIndex + 1, 4) if len(res.insertionCode) == 1: resIC = res.insertionCode else: @@ -377,22 +493,35 @@ def writeModel(topology, positions, file=sys.stdout, modelIndex=None, keepIds=Fa else: symbol = extraParticleIdentifier if len(atom.name) < 4 and atom.name[:1].isalpha() and len(symbol) < 2: - atomName = ' '+atom.name + atomName = " " + atom.name elif len(atom.name) > 4: atomName = atom.name[:4] else: atomName = atom.name coords = positions[posIndex] line = "%s%5s %-4s %3s %s%4s%1s %s%s%s 1.00 0.00 %2s " % ( - recordName, _formatIndex(atomIndex, 5), atomName, resName, chainName, resId, resIC, _format_83(coords[0]), - _format_83(coords[1]), _format_83(coords[2]), symbol) + recordName, + _formatIndex(atomIndex, 5), + atomName, + resName, + chainName, + resId, + resIC, + _format_83(coords[0]), + _format_83(coords[1]), + _format_83(coords[2]), + symbol, + ) if len(line) != 80: - raise ValueError('Fixed width overflow detected') + raise ValueError("Fixed width overflow detected") print(line, file=file) posIndex += 1 atomIndex += 1 - if resIndex == len(residues)-1: - print("TER %5s %3s %s%4s" % (_formatIndex(atomIndex, 5), resName, chainName, resId), file=file) + if resIndex == len(residues) - 1: + print( + "TER %5s %3s %s%4s" % (_formatIndex(atomIndex, 5), resName, chainName, resId), + file=file, + ) atomIndex += 1 if modelIndex is not None: print("ENDMDL", file=file) @@ -412,9 +541,17 @@ def writeFooter(topology, file=sys.stdout): conectBonds = [] for atom1, atom2 in topology.bonds(): - if atom1.residue.name not in PDBFile._standardResidues or atom2.residue.name not in PDBFile._standardResidues: + if ( + atom1.residue.name not in PDBFile._standardResidues + or atom2.residue.name not in PDBFile._standardResidues + ): conectBonds.append((atom1, atom2)) - elif atom1.name == 'SG' and atom2.name == 'SG' and atom1.residue.name == 'CYS' and atom2.residue.name == 'CYS': + elif ( + atom1.name == "SG" + and atom2.name == "SG" + and atom1.residue.name == "CYS" + and atom2.residue.name == "CYS" + ): conectBonds.append((atom1, atom2)) if len(conectBonds) > 0: @@ -449,7 +586,16 @@ def writeFooter(topology, file=sys.stdout): for index1 in sorted(atomBonds): bonded = atomBonds[index1] while len(bonded) > 4: - print("CONECT%5s%5s%5s%5s" % (_formatIndex(index1, 5), _formatIndex(bonded[0], 5), _formatIndex(bonded[1], 5), _formatIndex(bonded[2], 5)), file=file) + print( + "CONECT%5s%5s%5s%5s" + % ( + _formatIndex(index1, 5), + _formatIndex(bonded[0], 5), + _formatIndex(bonded[1], 5), + _formatIndex(bonded[2], 5), + ), + file=file, + ) del bonded[:4] line = "CONECT%5s" % _formatIndex(index1, 5) for index2 in bonded: @@ -464,19 +610,19 @@ def _format_83(f): gracefully degrade the precision by lopping off some of the decimal places. If it's much too large, we throw a ValueError""" if -999.999 < f < 9999.999: - return '%8.3f' % f + return "%8.3f" % f if -9999999 < f < 99999999: - return ('%8.3f' % f)[:8] - raise ValueError('coordinate "%s" could not be represented ' - 'in a width-8 field' % f) + return ("%8.3f" % f)[:8] + raise ValueError('coordinate "%s" could not be represented ' "in a width-8 field" % f) + def _formatIndex(index, places): """Create a string representation of an atom or residue index. If the value is larger than can fit in the available space, switch to hex. """ if index < 10**places: - format = f'%{places}d' + format = f"%{places}d" return format % index - format = f'%{places}X' - shiftedIndex = (index - 10**places + 10*16**(places-1)) % (16**places) - return format % shiftedIndex \ No newline at end of file + format = f"%{places}X" + shiftedIndex = (index - 10**places + 10 * 16 ** (places - 1)) % (16**places) + return format % shiftedIndex diff --git a/gufe/vendor/pdb_file/pdbstructure.py b/gufe/vendor/pdb_file/pdbstructure.py index 76c5f61a..cda68915 100644 --- a/gufe/vendor/pdb_file/pdbstructure.py +++ b/gufe/vendor/pdb_file/pdbstructure.py @@ -28,24 +28,24 @@ OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from __future__ import absolute_import -from __future__ import print_function + __author__ = "Christopher M. Bruns" __version__ = "1.0" +import math +import sys +import warnings +from collections import OrderedDict + +import numpy as np import openmm.unit as unit from . import element from .unitcell import computePeriodicBoxVectors -import numpy as np -import warnings -import sys -import math -from collections import OrderedDict -class PdbStructure(object): +class PdbStructure: """ PdbStructure object holds a parsed Protein Data Bank format file. @@ -125,8 +125,7 @@ class PdbStructure(object): methods. """ - - def __init__(self, input_stream, load_all_models=False, extraParticleIdentifier='EP'): + def __init__(self, input_stream, load_all_models=False, extraParticleIdentifier="EP"): """Create a PDB model from a PDB file stream. Parameters @@ -160,16 +159,16 @@ def _load(self, input_stream): # Read one line at a time for pdb_line in input_stream: if not isinstance(pdb_line, str): - pdb_line = pdb_line.decode('utf-8') + pdb_line = pdb_line.decode("utf-8") command = pdb_line[:6] # Look for atoms if command == "ATOM " or command == "HETATM": self._add_atom(Atom(pdb_line, self, self.extraParticleIdentifier)) elif command == "CONECT": atoms = [_parse_atom_index(pdb_line[6:11])] - for pos in (11,16,21,26): + for pos in (11, 16, 21, 26): try: - atoms.append(_parse_atom_index(pdb_line[pos:pos+5])) + atoms.append(_parse_atom_index(pdb_line[pos : pos + 5])) except: pass self._current_model.connects.append(atoms) @@ -192,21 +191,30 @@ def _load(self, input_stream): self._current_model._current_chain._add_ter_record() self._reset_residue_numbers() elif command == "CRYST1": - a_length = float(pdb_line[6:15])*0.1 - b_length = float(pdb_line[15:24])*0.1 - c_length = float(pdb_line[24:33])*0.1 - alpha = float(pdb_line[33:40])*math.pi/180.0 - beta = float(pdb_line[40:47])*math.pi/180.0 - gamma = float(pdb_line[47:54])*math.pi/180.0 + a_length = float(pdb_line[6:15]) * 0.1 + b_length = float(pdb_line[15:24]) * 0.1 + c_length = float(pdb_line[24:33]) * 0.1 + alpha = float(pdb_line[33:40]) * math.pi / 180.0 + beta = float(pdb_line[40:47]) * math.pi / 180.0 + gamma = float(pdb_line[47:54]) * math.pi / 180.0 if 0 not in (a_length, b_length, c_length): - self._periodic_box_vectors = computePeriodicBoxVectors(a_length, b_length, c_length, alpha, beta, gamma) + self._periodic_box_vectors = computePeriodicBoxVectors( + a_length, b_length, c_length, alpha, beta, gamma + ) elif command == "SEQRES": chain_id = pdb_line[11] if len(self.sequences) == 0 or chain_id != self.sequences[-1].chain_id: self.sequences.append(Sequence(chain_id)) self.sequences[-1].residues.extend(pdb_line[19:].split()) elif command == "MODRES": - self.modified_residues.append(ModifiedResidue(pdb_line[16], int(pdb_line[18:22]), pdb_line[12:15].strip(), pdb_line[24:27].strip())) + self.modified_residues.append( + ModifiedResidue( + pdb_line[16], + int(pdb_line[18:22]), + pdb_line[12:15].strip(), + pdb_line[24:27].strip(), + ) + ) self._finalize() def _reset_atom_numbers(self): @@ -248,30 +256,25 @@ def __getitem__(self, model_number): return self.models_by_number[model_number] def __iter__(self): - for model in self.models: - yield model + yield from self.models def iter_models(self, use_all_models=False): if use_all_models: - for model in self: - yield model + yield from self elif len(self.models) > 0: yield self.models[0] def iter_chains(self, use_all_models=False): for model in self.iter_models(use_all_models): - for chain in model.iter_chains(): - yield chain + yield from model.iter_chains() def iter_residues(self, use_all_models=False): for model in self.iter_models(use_all_models): - for res in model.iter_residues(): - yield res + yield from model.iter_residues() def iter_atoms(self, use_all_models=False): for model in self.iter_models(use_all_models): - for atom in model.iter_atoms(): - yield atom + yield from model.iter_atoms() def iter_positions(self, use_all_models=False, include_alt_loc=False): """ @@ -285,15 +288,13 @@ def iter_positions(self, use_all_models=False, include_alt_loc=False): Get all positions for each atom, or just the first one. """ for model in self.iter_models(use_all_models): - for loc in model.iter_positions(include_alt_loc): - yield loc + yield from model.iter_positions(include_alt_loc) def __len__(self): return len(self.models) def _add_atom(self, atom): - """ - """ + """ """ if self._current_model is None: self._add_model(Model(0)) atom.model_number = self._current_model.number @@ -310,14 +311,17 @@ def get_periodic_box_vectors(self): return self._periodic_box_vectors -class Sequence(object): +class Sequence: """Sequence holds the sequence of a chain, as specified by SEQRES records.""" + def __init__(self, chain_id): self.chain_id = chain_id self.residues = [] -class ModifiedResidue(object): + +class ModifiedResidue: """ModifiedResidue holds information about a modified residue, as specified by a MODRES record.""" + def __init__(self, chain_id, number, residue_name, standard_name): self.chain_id = chain_id self.number = number @@ -325,12 +329,13 @@ def __init__(self, chain_id, number, residue_name, standard_name): self.standard_name = standard_name -class Model(object): +class Model: """Model holds one model of a PDB structure. NMR structures usually have multiple models. This represents one of them. """ + def __init__(self, model_number=1): self.number = model_number self.chains = [] @@ -339,8 +344,7 @@ def __init__(self, model_number=1): self.connects = [] def _add_atom(self, atom): - """ - """ + """ """ if len(self.chains) == 0: self._add_chain(Chain(atom.chain_id)) # Create a new chain if the chain id has changed @@ -373,23 +377,19 @@ def __iter__(self): return iter(self.chains) def iter_chains(self): - for chain in self: - yield chain + yield from self def iter_residues(self): for chain in self: - for res in chain.iter_residues(): - yield res + yield from chain.iter_residues() def iter_atoms(self): for chain in self: - for atom in chain.iter_atoms(): - yield atom + yield from chain.iter_atoms() def iter_positions(self, include_alt_loc=False): for chain in self: - for loc in chain.iter_positions(include_alt_loc): - yield loc + yield from chain.iter_positions(include_alt_loc) def __len__(self): return len(self.chains) @@ -404,9 +404,9 @@ def _finalize(self): for chain in self.chains: chain._finalize() - - class AtomSerialNumber(object): + class AtomSerialNumber: """pdb.Model inner class for pass-by-reference incrementable serial number""" + def __init__(self, val): self.val = val @@ -414,8 +414,8 @@ def increment(self): self.val += 1 -class Chain(object): - def __init__(self, chain_id=' '): +class Chain: + def __init__(self, chain_id=" "): self.chain_id = chain_id self.residues = [] self.has_ter_record = False @@ -424,26 +424,57 @@ def __init__(self, chain_id=' '): self.residues_by_number = {} def _add_atom(self, atom): - """ - """ + """ """ # Create a residue if none have been created if len(self.residues) == 0: - self._add_residue(Residue(atom.residue_name_with_spaces, atom.residue_number, atom.insertion_code, atom.alternate_location_indicator)) + self._add_residue( + Residue( + atom.residue_name_with_spaces, + atom.residue_number, + atom.insertion_code, + atom.alternate_location_indicator, + ) + ) # Create a residue if the residue information has changed elif self._current_residue.number != atom.residue_number: - self._add_residue(Residue(atom.residue_name_with_spaces, atom.residue_number, atom.insertion_code, atom.alternate_location_indicator)) + self._add_residue( + Residue( + atom.residue_name_with_spaces, + atom.residue_number, + atom.insertion_code, + atom.alternate_location_indicator, + ) + ) elif self._current_residue.insertion_code != atom.insertion_code: - self._add_residue(Residue(atom.residue_name_with_spaces, atom.residue_number, atom.insertion_code, atom.alternate_location_indicator)) + self._add_residue( + Residue( + atom.residue_name_with_spaces, + atom.residue_number, + atom.insertion_code, + atom.alternate_location_indicator, + ) + ) elif self._current_residue.name_with_spaces == atom.residue_name_with_spaces: # This is a normal case: number, name, and iCode have not changed pass - elif atom.alternate_location_indicator != ' ': + elif atom.alternate_location_indicator != " ": # OK - this is a point mutation, Residue._add_atom will know what to do pass - else: # Residue name does not match + else: # Residue name does not match # Only residue name does not match - warnings.warn("WARNING: two consecutive residues with same number (%s, %s)" % (atom, self._current_residue.atoms[-1])) - self._add_residue(Residue(atom.residue_name_with_spaces, atom.residue_number, atom.insertion_code, atom.alternate_location_indicator)) + warnings.warn( + "WARNING: two consecutive residues with same number ({}, {})".format( + atom, self._current_residue.atoms[-1] + ) + ) + self._add_residue( + Residue( + atom.residue_name_with_spaces, + atom.residue_number, + atom.insertion_code, + atom.alternate_location_indicator, + ) + ) self._current_residue._add_atom(atom) def _add_residue(self, residue): @@ -463,14 +494,24 @@ def write(self, next_serial_number, output_stream=sys.stdout): residue.write(next_serial_number, output_stream) if self.has_ter_record: r = self.residues[-1] - print("TER %5d %3s %1s%4d%1s" % (next_serial_number.val, r.name_with_spaces, self.chain_id, r.number, r.insertion_code), file=output_stream) + print( + "TER %5d %3s %1s%4d%1s" + % ( + next_serial_number.val, + r.name_with_spaces, + self.chain_id, + r.number, + r.insertion_code, + ), + file=output_stream, + ) next_serial_number.increment() def _add_ter_record(self): self.has_ter_record = True self._finalize() - def get_residue(self, residue_number, insertion_code=' '): + def get_residue(self, residue_number, insertion_code=" "): return self.residues_by_num_icode[str(residue_number) + insertion_code] def __contains__(self, residue_number): @@ -481,22 +522,18 @@ def __getitem__(self, residue_number): return self.residues_by_number[residue_number] def __iter__(self): - for res in self.residues: - yield res + yield from self.residues def iter_residues(self): - for res in self: - yield res + yield from self def iter_atoms(self): for res in self: - for atom in res: - yield atom; + yield from res def iter_positions(self, include_alt_loc=False): for res in self: - for loc in res.iter_positions(include_alt_loc): - yield loc + yield from res.iter_positions(include_alt_loc) def __len__(self): return len(self.residues) @@ -508,8 +545,8 @@ def _finalize(self): residue._finalize() -class Residue(object): - def __init__(self, name, number, insertion_code=' ', primary_alternate_location_indicator=' '): +class Residue: + def __init__(self, name, number, insertion_code=" ", primary_alternate_location_indicator=" "): alt_loc = primary_alternate_location_indicator self.primary_location_id = alt_loc self.locations = {} @@ -524,30 +561,35 @@ def __init__(self, name, number, insertion_code=' ', primary_alternate_location_ self._current_atom = None def _add_atom(self, atom): - """ - """ + """ """ alt_loc = atom.alternate_location_indicator if alt_loc not in self.locations: self.locations[alt_loc] = Residue.Location(alt_loc, atom.residue_name_with_spaces) assert atom.residue_number == self.number assert atom.insertion_code == self.insertion_code # Check whether this is an existing atom with another position - if (atom.name_with_spaces in self.atoms_by_name): + if atom.name_with_spaces in self.atoms_by_name: old_atom = self.atoms_by_name[atom.name_with_spaces] # Unless this is a duplicated atom (warn about file error) if atom.alternate_location_indicator in old_atom.locations: - warnings.warn("WARNING: duplicate atom (%s, %s)" % (atom, old_atom._pdb_string(old_atom.serial_number, atom.alternate_location_indicator))) + warnings.warn( + "WARNING: duplicate atom (%s, %s)" + % ( + atom, + old_atom._pdb_string(old_atom.serial_number, atom.alternate_location_indicator), + ) + ) else: for alt_loc, position in atom.locations.items(): old_atom.locations[alt_loc] = position - return # no new atom added + return # no new atom added # actually use new atom self.atoms_by_name[atom.name] = atom self.atoms_by_name[atom.name_with_spaces] = atom self.atoms.append(atom) self._current_atom = atom - def write(self, next_serial_number, output_stream=sys.stdout, alt_loc = "*"): + def write(self, next_serial_number, output_stream=sys.stdout, alt_loc="*"): for atom in self.atoms: atom.write(next_serial_number, output_stream, alt_loc) @@ -567,19 +609,26 @@ def set_name_with_spaces(self, name, alt_loc=None): loc = self.locations[alt_loc] loc.name_with_spaces = name loc.name = name.strip() + def get_name_with_spaces(self, alt_loc=None): if alt_loc is None: alt_loc = self.primary_location_id loc = self.locations[alt_loc] return loc.name_with_spaces - name_with_spaces = property(get_name_with_spaces, set_name_with_spaces, doc='four-character residue name including spaces') + + name_with_spaces = property( + get_name_with_spaces, + set_name_with_spaces, + doc="four-character residue name including spaces", + ) def get_name(self, alt_loc=None): if alt_loc is None: alt_loc = self.primary_location_id loc = self.locations[alt_loc] return loc.name - name = property(get_name, doc='residue name') + + name = property(get_name, doc="residue name") def get_atom(self, atom_name): return self.atoms_by_name[atom_name] @@ -613,8 +662,7 @@ def __iter__(self): ATOM 192 CB CYS A 42 38.949 -6.825 12.002 1.00 9.67 C ATOM 193 SG CYS A 42 37.557 -7.514 12.922 1.00 20.12 S """ - for atom in self.iter_atoms(): - yield atom + yield from self.iter_atoms() # Three possibilities: primary alt_loc, certain alt_loc, or all alt_locs def iter_atoms(self, alt_loc=None): @@ -628,10 +676,10 @@ def iter_atoms(self, alt_loc=None): locs = list(alt_loc) # If an atom has any location in alt_loc, emit the atom for atom in self.atoms: - use_atom = False # start pessimistic + use_atom = False # start pessimistic for loc2 in atom.locations.keys(): # print "#%s#%s" % (loc2,locs) - if locs is None: # means all locations + if locs is None: # means all locations use_atom = True elif loc2 in locs: use_atom = True @@ -667,8 +715,7 @@ def iter_positions(self, include_alt_loc=False): """ for atom in self: if include_alt_loc: - for loc in atom.iter_positions(): - yield loc + yield from atom.iter_positions() else: yield atom.position @@ -680,15 +727,16 @@ class Location: """ Inner class of residue to allow different residue names for different alternate_locations. """ + def __init__(self, alternate_location_indicator, residue_name_with_spaces): self.alternate_location_indicator = alternate_location_indicator self.residue_name_with_spaces = residue_name_with_spaces -class Atom(object): - """Atom represents one atom in a PDB structure. - """ - def __init__(self, pdb_line, pdbstructure=None, extraParticleIdentifier='EP'): +class Atom: + """Atom represents one atom in a PDB structure.""" + + def __init__(self, pdb_line, pdbstructure=None, extraParticleIdentifier="EP"): """Create a new pdb.Atom from an ATOM or HETATM line. Example line: @@ -741,7 +789,7 @@ def __init__(self, pdb_line, pdbstructure=None, extraParticleIdentifier='EP'): if possible_fourth_character != " ": # Fourth character should only be there if official 3 are already full if len(self.residue_name_with_spaces.strip()) != 3: - raise ValueError('Misaligned residue name: %s' % pdb_line) + raise ValueError("Misaligned residue name: %s" % pdb_line) self.residue_name_with_spaces += possible_fourth_character self.residue_name = self.residue_name_with_spaces.strip() @@ -754,7 +802,11 @@ def __init__(self, pdb_line, pdbstructure=None, extraParticleIdentifier='EP'): except: # When VMD runs out of hex values it starts filling the residue ID field with ****. # Look at the most recent atoms to figure out whether this is a new residue or not. - if pdbstructure._current_model is None or pdbstructure._current_model._current_chain is None or pdbstructure._current_model._current_chain._current_residue is None: + if ( + pdbstructure._current_model is None + or pdbstructure._current_model._current_chain is None + or pdbstructure._current_model._current_chain._current_residue is None + ): # This is the first residue in the model. self.residue_number = pdbstructure._next_residue_number else: @@ -781,17 +833,25 @@ def __init__(self, pdb_line, pdbstructure=None, extraParticleIdentifier='EP'): except: temperature_factor = unit.Quantity(0.0, unit.angstroms**2) self.locations = {} - loc = Atom.Location(alternate_location_indicator, unit.Quantity(np.array([x,y,z]), unit.angstroms), occupancy, temperature_factor, self.residue_name_with_spaces) + loc = Atom.Location( + alternate_location_indicator, + unit.Quantity(np.array([x, y, z]), unit.angstroms), + occupancy, + temperature_factor, + self.residue_name_with_spaces, + ) self.locations[alternate_location_indicator] = loc self.default_location_id = alternate_location_indicator # segment id, element_symbol, and formal_charge are not always present self.segment_id = pdb_line[72:76].strip() self.element_symbol = pdb_line[76:78].strip() - try: self.formal_charge = int(pdb_line[78:80]) - except ValueError: self.formal_charge = None + try: + self.formal_charge = int(pdb_line[78:80]) + except ValueError: + self.formal_charge = None # figure out atom element if self.element_symbol == extraParticleIdentifier: - self.element = 'EP' + self.element = "EP" else: try: # Try to find a sensible element symbol from columns 76-77 @@ -799,8 +859,8 @@ def __init__(self, pdb_line, pdbstructure=None, extraParticleIdentifier='EP'): except KeyError: self.element = None if pdbstructure is not None: - pdbstructure._next_atom_number = self.serial_number+1 - pdbstructure._next_residue_number = self.residue_number+1 + pdbstructure._next_atom_number = self.serial_number + 1 + pdbstructure._next_residue_number = self.residue_number + 1 def iter_locations(self): """ @@ -835,8 +895,7 @@ def iter_coordinates(self): 22.607 A 20.046 A """ - for coord in self.position: - yield coord + yield from self.position # Hide existence of multiple alternate locations to avoid scaring casual users def get_location(self, location_id=None): @@ -844,38 +903,51 @@ def get_location(self, location_id=None): if id is None: id = self.default_location_id return self.locations[id] + def set_location(self, new_location, location_id=None): id = location_id if id is None: id = self.default_location_id self.locations[id] = new_location - location = property(get_location, set_location, doc='default Atom.Location object') + + location = property(get_location, set_location, doc="default Atom.Location object") def get_position(self): return self.location.position + def set_position(self, coords): self.location.position = coords - position = property(get_position, set_position, doc='orthogonal coordinates') + + position = property(get_position, set_position, doc="orthogonal coordinates") def get_alternate_location_indicator(self): return self.location.alternate_location_indicator + alternate_location_indicator = property(get_alternate_location_indicator) def get_occupancy(self): return self.location.occupancy + occupancy = property(get_occupancy) def get_temperature_factor(self): return self.location.temperature_factor + temperature_factor = property(get_temperature_factor) - def get_x(self): return self.position[0] + def get_x(self): + return self.position[0] + x = property(get_x) - def get_y(self): return self.position[1] + def get_y(self): + return self.position[1] + y = property(get_y) - def get_z(self): return self.position[2] + def get_z(self): + return self.position[2] + z = property(get_z) def _pdb_string(self, serial_number=None, alternate_location_indicator=None): @@ -893,26 +965,32 @@ def _pdb_string(self, serial_number=None, alternate_location_indicator=None): long_res_name += " " assert len(long_res_name) == 4 names = "%-6s%5d %4s%1s%4s%1s%4d%1s " % ( - self.record_name, serial_number, \ - self.name_with_spaces, alternate_location_indicator, \ - long_res_name, self.chain_id, \ - self.residue_number, self.insertion_code) - numbers = "%8.3f%8.3f%8.3f%6.2f%6.2f " % ( - self.x.value_in_unit(unit.angstroms), \ - self.y.value_in_unit(unit.angstroms), \ - self.z.value_in_unit(unit.angstroms), \ - self.occupancy, \ - self.temperature_factor.value_in_unit(unit.angstroms * unit.angstroms)) - end = "%-4s%2s" % (\ - self.segment_id, self.element_symbol) + self.record_name, + serial_number, + self.name_with_spaces, + alternate_location_indicator, + long_res_name, + self.chain_id, + self.residue_number, + self.insertion_code, + ) + numbers = "{:8.3f}{:8.3f}{:8.3f}{:6.2f}{:6.2f} ".format( + self.x.value_in_unit(unit.angstroms), + self.y.value_in_unit(unit.angstroms), + self.z.value_in_unit(unit.angstroms), + self.occupancy, + self.temperature_factor.value_in_unit(unit.angstroms * unit.angstroms), + ) + end = "%-4s%2s" % (self.segment_id, self.element_symbol) formal_charge = " " - if (self.formal_charge != None): formal_charge = "%+2d" % self.formal_charge - return names+numbers+end+formal_charge + if self.formal_charge != None: + formal_charge = "%+2d" % self.formal_charge + return names + numbers + end + formal_charge def __str__(self): return self._pdb_string(self.serial_number, self.alternate_location_indicator) - def write(self, next_serial_number, output_stream=sys.stdout, alt_loc = "*"): + def write(self, next_serial_number, output_stream=sys.stdout, alt_loc="*"): """ alt_loc = "*" means write all alternate locations alt_loc = None means write just the primary location @@ -935,18 +1013,26 @@ def set_name_with_spaces(self, name): assert len(name) == 4 self._name_with_spaces = name self._name = name.strip() + def get_name_with_spaces(self): return self._name_with_spaces - name_with_spaces = property(get_name_with_spaces, set_name_with_spaces, doc='four-character residue name including spaces') + + name_with_spaces = property( + get_name_with_spaces, + set_name_with_spaces, + doc="four-character residue name including spaces", + ) def get_name(self): return self._name - name = property(get_name, doc='residue name') - class Location(object): + name = property(get_name, doc="residue name") + + class Location: """ Inner class of Atom for holding alternate locations """ + def __init__(self, alt_loc, position, occupancy, temperature_factor, residue_name): self.alternate_location_indicator = alt_loc self.position = position @@ -968,8 +1054,7 @@ def __iter__(self): 2 A 3 A """ - for coord in self.position: - yield coord + yield from self.position def __str__(self): return str(self.position) @@ -982,14 +1067,17 @@ def _parse_atom_index(index): except: return int(index, 16) - 0xA0000 + 100000 + # run module directly for testing -if __name__=='__main__': +if __name__ == "__main__": # Test the examples in the docstrings - import doctest, sys + import doctest + import sys + doctest.testmod(sys.modules[__name__]) - import os import gzip + import os import re import time @@ -1006,7 +1094,7 @@ def _parse_atom_index(index): assert a.residue_number == 299 assert a.insertion_code == " " assert a.alternate_location_indicator == " " - assert a.x == 6.167 * unit.angstroms + assert a.x == 6.167 * unit.angstroms assert a.y == 22.607 * unit.angstroms assert a.z == 20.046 * unit.angstroms assert a.occupancy == 1.00 @@ -1021,8 +1109,9 @@ def _parse_atom_index(index): # misaligned residue name - bad try: a = Atom("ATOM 2209 CB TYRA 299 6.167 22.607 20.046 1.00 8.12 C") - assert(False) - except ValueError: pass + assert False + except ValueError: + pass # four character residue name -- not so bad a = Atom("ATOM 2209 CB NTYRA 299 6.167 22.607 20.046 1.00 8.12 C") @@ -1070,18 +1159,19 @@ def parse_one_pdb(pdb_file_name): subdir = "ae" full_subdir = os.path.join(pdb_dir, subdir) for pdb_file in os.listdir(full_subdir): - if not re.match("pdb.%2s.\.ent\.gz" % subdir , pdb_file): + if not re.match(r"pdb.%2s.\.ent\.gz" % subdir, pdb_file): continue full_pdb_file = os.path.join(full_subdir, pdb_file) parse_one_pdb(full_pdb_file) if parse_entire_pdb: for subdir in os.listdir(pdb_dir): - if not len(subdir) == 2: continue + if not len(subdir) == 2: + continue full_subdir = os.path.join(pdb_dir, subdir) if not os.path.isdir(full_subdir): continue for pdb_file in os.listdir(full_subdir): - if not re.match("pdb.%2s.\.ent\.gz" % subdir , pdb_file): + if not re.match(r"pdb.%2s.\.ent\.gz" % subdir, pdb_file): continue full_pdb_file = os.path.join(full_subdir, pdb_file) parse_one_pdb(full_pdb_file) @@ -1098,4 +1188,4 @@ def parse_one_pdb(pdb_file_name): print("%d residues found" % residue_count) print("%d chains found" % chain_count) print("%d models found" % model_count) - print("%d structures found" % structure_count) \ No newline at end of file + print("%d structures found" % structure_count) diff --git a/gufe/vendor/pdb_file/pdbxfile.py b/gufe/vendor/pdb_file/pdbxfile.py index 60d83246..e5d6433f 100644 --- a/gufe/vendor/pdb_file/pdbxfile.py +++ b/gufe/vendor/pdb_file/pdbxfile.py @@ -28,27 +28,25 @@ OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from __future__ import division, absolute_import, print_function __author__ = "Peter Eastman" __version__ = "2.0" -from openmm.unit import nanometers, angstroms, is_quantity, norm, Quantity - - -import sys import math -import numpy as np - +import sys from datetime import date +import numpy as np +from openmm.unit import Quantity, angstroms, is_quantity, nanometers, norm + +from . import element as elem +from .pdbfile import PDBFile from .PdbxReader import PdbxReader -from .unitcell import computePeriodicBoxVectors, computeLengthsAndAngles from .topology import Topology -from .pdbfile import PDBFile -from . import element as elem +from .unitcell import computeLengthsAndAngles, computePeriodicBoxVectors + -class PDBxFile(object): +class PDBxFile: """PDBxFile parses a PDBx/mmCIF file and constructs a Topology and a set of atom positions from it.""" def __init__(self, file): @@ -84,52 +82,56 @@ def __init__(self, file): # Build the topology. - atomData = block.getObj('atom_site') - atomNameCol = atomData.getAttributeIndex('auth_atom_id') + atomData = block.getObj("atom_site") + atomNameCol = atomData.getAttributeIndex("auth_atom_id") if atomNameCol == -1: - atomNameCol = atomData.getAttributeIndex('label_atom_id') - atomIdCol = atomData.getAttributeIndex('id') - resNameCol = atomData.getAttributeIndex('auth_comp_id') + atomNameCol = atomData.getAttributeIndex("label_atom_id") + atomIdCol = atomData.getAttributeIndex("id") + resNameCol = atomData.getAttributeIndex("auth_comp_id") if resNameCol == -1: - resNameCol = atomData.getAttributeIndex('label_comp_id') - resNumCol = atomData.getAttributeIndex('auth_seq_id') + resNameCol = atomData.getAttributeIndex("label_comp_id") + resNumCol = atomData.getAttributeIndex("auth_seq_id") if resNumCol == -1: - resNumCol = atomData.getAttributeIndex('label_seq_id') - resInsertionCol = atomData.getAttributeIndex('pdbx_PDB_ins_code') - chainIdCol = atomData.getAttributeIndex('auth_asym_id') + resNumCol = atomData.getAttributeIndex("label_seq_id") + resInsertionCol = atomData.getAttributeIndex("pdbx_PDB_ins_code") + chainIdCol = atomData.getAttributeIndex("auth_asym_id") if chainIdCol == -1: - chainIdCol = atomData.getAttributeIndex('label_asym_id') + chainIdCol = atomData.getAttributeIndex("label_asym_id") altChainIdCol = -1 else: - altChainIdCol = atomData.getAttributeIndex('label_asym_id') + altChainIdCol = atomData.getAttributeIndex("label_asym_id") if altChainIdCol != -1: # Figure out which column is best to use for chain IDs. - - idSet = set(row[chainIdCol] for row in atomData.getRowList()) - altIdSet = set(row[altChainIdCol] for row in atomData.getRowList()) + + idSet = {row[chainIdCol] for row in atomData.getRowList()} + altIdSet = {row[altChainIdCol] for row in atomData.getRowList()} if len(altIdSet) > len(idSet): chainIdCol, altChainIdCol = (altChainIdCol, chainIdCol) - elementCol = atomData.getAttributeIndex('type_symbol') - altIdCol = atomData.getAttributeIndex('label_alt_id') - modelCol = atomData.getAttributeIndex('pdbx_PDB_model_num') - xCol = atomData.getAttributeIndex('Cartn_x') - yCol = atomData.getAttributeIndex('Cartn_y') - zCol = atomData.getAttributeIndex('Cartn_z') + elementCol = atomData.getAttributeIndex("type_symbol") + altIdCol = atomData.getAttributeIndex("label_alt_id") + modelCol = atomData.getAttributeIndex("pdbx_PDB_model_num") + xCol = atomData.getAttributeIndex("Cartn_x") + yCol = atomData.getAttributeIndex("Cartn_y") + zCol = atomData.getAttributeIndex("Cartn_z") lastChainId = None lastAltChainId = None lastResId = None - lastInsertionCode = '' + lastInsertionCode = "" atomTable = {} atomsInResidue = set() models = [] for row in atomData.getRowList(): - atomKey = ((row[resNumCol], row[chainIdCol], row[atomNameCol])) - model = ('1' if modelCol == -1 else row[modelCol]) + atomKey = (row[resNumCol], row[chainIdCol], row[atomNameCol]) + model = "1" if modelCol == -1 else row[modelCol] if model not in models: models.append(model) self._positions.append([]) modelIndex = models.index(model) - if row[altIdCol] != '.' and atomKey in atomTable and len(self._positions[modelIndex]) > atomTable[atomKey].index: + if ( + row[altIdCol] != "." + and atomKey in atomTable + and len(self._positions[modelIndex]) > atomTable[atomKey].index + ): # This row is an alternate position for an existing atom, so ignore it. continue @@ -137,11 +139,11 @@ def __init__(self, file): # This row defines a new atom. if resInsertionCol == -1: - insertionCode = '' + insertionCode = "" else: insertionCode = row[resInsertionCol] - if insertionCode in ('.', '?'): - insertionCode = '' + if insertionCode in (".", "?"): + insertionCode = "" if lastChainId != row[chainIdCol] or (altChainIdCol != -1 and lastAltChainId != row[altChainIdCol]): # The start of a new chain. chain = top.addChain(row[chainIdCol]) @@ -149,9 +151,14 @@ def __init__(self, file): lastResId = None if altChainIdCol != -1: lastAltChainId = row[altChainIdCol] - if lastResId != row[resNumCol] or lastChainId != row[chainIdCol] or lastInsertionCode != insertionCode or (lastResId == '.' and row[atomNameCol] in atomsInResidue): + if ( + lastResId != row[resNumCol] + or lastChainId != row[chainIdCol] + or lastInsertionCode != insertionCode + or (lastResId == "." and row[atomNameCol] in atomsInResidue) + ): # The start of a new residue. - resId = (None if resNumCol == -1 else row[resNumCol]) + resId = None if resNumCol == -1 else row[resNumCol] resIC = insertionCode res = top.addResidue(row[resNameCol], chain, resId, resIC) lastResId = row[resNumCol] @@ -171,12 +178,18 @@ def __init__(self, file): try: atom = atomTable[atomKey] except KeyError: - raise ValueError('Unknown atom %s in residue %s %s for model %s' % (row[atomNameCol], row[resNameCol], row[resNumCol], model)) + raise ValueError( + "Unknown atom %s in residue %s %s for model %s" + % (row[atomNameCol], row[resNameCol], row[resNumCol], model) + ) if atom.index != len(self._positions[modelIndex]): - raise ValueError('Atom %s for model %s does not match the order of atoms for model %s' % (row[atomIdCol], model, models[0])) - self._positions[modelIndex].append(np.array([float(row[xCol]), float(row[yCol]), float(row[zCol])])*0.1) + raise ValueError( + "Atom %s for model %s does not match the order of atoms for model %s" + % (row[atomIdCol], model, models[0]) + ) + self._positions[modelIndex].append(np.array([float(row[xCol]), float(row[yCol]), float(row[zCol])]) * 0.1) for i in range(len(self._positions)): - self._positions[i] = self._positions[i]*nanometers + self._positions[i] = self._positions[i] * nanometers ## The atom positions read from the PDBx/mmCIF file. If the file contains multiple frames, these are the positions in the first frame. self.positions = self._positions[0] self.topology.createStandardBonds() @@ -184,28 +197,34 @@ def __init__(self, file): # Record unit cell information, if present. - cell = block.getObj('cell') + cell = block.getObj("cell") if cell is not None and cell.getRowCount() > 0: row = cell.getRow(0) - (a, b, c) = [float(row[cell.getAttributeIndex(attribute)])*0.1 for attribute in ('length_a', 'length_b', 'length_c')] - (alpha, beta, gamma) = [float(row[cell.getAttributeIndex(attribute)])*math.pi/180.0 for attribute in ('angle_alpha', 'angle_beta', 'angle_gamma')] + (a, b, c) = ( + float(row[cell.getAttributeIndex(attribute)]) * 0.1 + for attribute in ("length_a", "length_b", "length_c") + ) + (alpha, beta, gamma) = ( + float(row[cell.getAttributeIndex(attribute)]) * math.pi / 180.0 + for attribute in ("angle_alpha", "angle_beta", "angle_gamma") + ) self.topology.setPeriodicBoxVectors(computePeriodicBoxVectors(a, b, c, alpha, beta, gamma)) # Add bonds based on struct_conn records. - connectData = block.getObj('struct_conn') + connectData = block.getObj("struct_conn") if connectData is not None: - res1Col = connectData.getAttributeIndex('ptnr1_label_seq_id') - res2Col = connectData.getAttributeIndex('ptnr2_label_seq_id') - atom1Col = connectData.getAttributeIndex('ptnr1_label_atom_id') - atom2Col = connectData.getAttributeIndex('ptnr2_label_atom_id') - asym1Col = connectData.getAttributeIndex('ptnr1_label_asym_id') - asym2Col = connectData.getAttributeIndex('ptnr2_label_asym_id') - typeCol = connectData.getAttributeIndex('conn_type_id') + res1Col = connectData.getAttributeIndex("ptnr1_label_seq_id") + res2Col = connectData.getAttributeIndex("ptnr2_label_seq_id") + atom1Col = connectData.getAttributeIndex("ptnr1_label_atom_id") + atom2Col = connectData.getAttributeIndex("ptnr2_label_atom_id") + asym1Col = connectData.getAttributeIndex("ptnr1_label_asym_id") + asym2Col = connectData.getAttributeIndex("ptnr2_label_asym_id") + typeCol = connectData.getAttributeIndex("conn_type_id") connectBonds = [] for row in connectData.getRowList(): type = row[typeCol][:6] - if type in ('covale', 'disulf', 'modres'): + if type in ("covale", "disulf", "modres"): key1 = (row[res1Col], row[asym1Col], row[atom1Col]) key2 = (row[res2Col], row[asym2Col], row[atom2Col]) if key1 in atomTable and key2 in atomTable: @@ -239,15 +258,17 @@ def getPositions(self, asnp=False, frame=0): """ if asnp: if self._npPositions is None: - self._npPositions = [None]*len(self._positions) + self._npPositions = [None] * len(self._positions) if self._npPositions[frame] is None: - self._npPositions[frame] = Quantity(np.array(self._positions[frame].value_in_unit(nanometers)), nanometers) + self._npPositions[frame] = Quantity( + np.array(self._positions[frame].value_in_unit(nanometers)), + nanometers, + ) return self._npPositions[frame] return self._positions[frame] @staticmethod - def writeFile(topology, positions, file=sys.stdout, keepIds=False, - entry=None): + def writeFile(topology, positions, file=sys.stdout, keepIds=False, entry=None): """Write a PDBx/mmCIF file containing a single model. Parameters @@ -288,46 +309,54 @@ def writeHeader(topology, file=sys.stdout, entry=None, keepIds=False): PDBx/mmCIF format. Otherwise, the output file will be invalid. """ if entry is not None: - print('data_%s' % entry, file=file) + print("data_%s" % entry, file=file) else: - print('data_cell', file=file) + print("data_cell", file=file) print("# Created with OpenMM %s" % (str(date.today())), file=file) - print('#', file=file) + print("#", file=file) vectors = topology.getPeriodicBoxVectors() if vectors is not None: a, b, c, alpha, beta, gamma = computeLengthsAndAngles(vectors) - RAD_TO_DEG = 180/math.pi - print('_cell.length_a %10.4f' % (a*10), file=file) - print('_cell.length_b %10.4f' % (b*10), file=file) - print('_cell.length_c %10.4f' % (c*10), file=file) - print('_cell.angle_alpha %10.4f' % (alpha*RAD_TO_DEG), file=file) - print('_cell.angle_beta %10.4f' % (beta*RAD_TO_DEG), file=file) - print('_cell.angle_gamma %10.4f' % (gamma*RAD_TO_DEG), file=file) - print('#', file=file) + RAD_TO_DEG = 180 / math.pi + print("_cell.length_a %10.4f" % (a * 10), file=file) + print("_cell.length_b %10.4f" % (b * 10), file=file) + print("_cell.length_c %10.4f" % (c * 10), file=file) + print("_cell.angle_alpha %10.4f" % (alpha * RAD_TO_DEG), file=file) + print("_cell.angle_beta %10.4f" % (beta * RAD_TO_DEG), file=file) + print("_cell.angle_gamma %10.4f" % (gamma * RAD_TO_DEG), file=file) + print("#", file=file) # Identify bonds that should be listed in the file. bonds = [] for atom1, atom2 in topology.bonds(): - if atom1.residue.name not in PDBFile._standardResidues or atom2.residue.name not in PDBFile._standardResidues: + if ( + atom1.residue.name not in PDBFile._standardResidues + or atom2.residue.name not in PDBFile._standardResidues + ): bonds.append((atom1, atom2)) - elif atom1.name == 'SG' and atom2.name == 'SG' and atom1.residue.name == 'CYS' and atom2.residue.name == 'CYS': + elif ( + atom1.name == "SG" + and atom2.name == "SG" + and atom1.residue.name == "CYS" + and atom2.residue.name == "CYS" + ): bonds.append((atom1, atom2)) if len(bonds) > 0: # Write the bond information. - print('loop_', file=file) - print('_struct_conn.id', file=file) - print('_struct_conn.conn_type_id', file=file) - print('_struct_conn.ptnr1_label_asym_id', file=file) - print('_struct_conn.ptnr1_label_comp_id', file=file) - print('_struct_conn.ptnr1_label_seq_id', file=file) - print('_struct_conn.ptnr1_label_atom_id', file=file) - print('_struct_conn.ptnr2_label_asym_id', file=file) - print('_struct_conn.ptnr2_label_comp_id', file=file) - print('_struct_conn.ptnr2_label_seq_id', file=file) - print('_struct_conn.ptnr2_label_atom_id', file=file) + print("loop_", file=file) + print("_struct_conn.id", file=file) + print("_struct_conn.conn_type_id", file=file) + print("_struct_conn.ptnr1_label_asym_id", file=file) + print("_struct_conn.ptnr1_label_comp_id", file=file) + print("_struct_conn.ptnr1_label_seq_id", file=file) + print("_struct_conn.ptnr1_label_atom_id", file=file) + print("_struct_conn.ptnr2_label_asym_id", file=file) + print("_struct_conn.ptnr2_label_comp_id", file=file) + print("_struct_conn.ptnr2_label_seq_id", file=file) + print("_struct_conn.ptnr2_label_atom_id", file=file) chainIds = {} resIds = {} if keepIds: @@ -336,49 +365,63 @@ def writeHeader(topology, file=sys.stdout, entry=None, keepIds=False): for res in topology.residues(): resIds[res] = res.id else: - for (chainIndex, chain) in enumerate(topology.chains()): - chainIds[chain] = chr(ord('A')+chainIndex%26) - for (resIndex, res) in enumerate(chain.residues()): - resIds[res] = resIndex+1 + for chainIndex, chain in enumerate(topology.chains()): + chainIds[chain] = chr(ord("A") + chainIndex % 26) + for resIndex, res in enumerate(chain.residues()): + resIds[res] = resIndex + 1 for i, (atom1, atom2) in enumerate(bonds): - if atom1.residue.name == 'CYS' and atom2.residue.name == 'CYS': - bondType = 'disulf' + if atom1.residue.name == "CYS" and atom2.residue.name == "CYS": + bondType = "disulf" else: - bondType = 'covale' + bondType = "covale" line = "bond%d %s %s %-4s %5s %-4s %s %-4s %5s %-4s" - print(line % (i+1, bondType, chainIds[atom1.residue.chain], atom1.residue.name, resIds[atom1.residue], atom1.name, - chainIds[atom2.residue.chain], atom2.residue.name, resIds[atom2.residue], atom2.name), file=file) - print('#', file=file) + print( + line + % ( + i + 1, + bondType, + chainIds[atom1.residue.chain], + atom1.residue.name, + resIds[atom1.residue], + atom1.name, + chainIds[atom2.residue.chain], + atom2.residue.name, + resIds[atom2.residue], + atom2.name, + ), + file=file, + ) + print("#", file=file) # Write the header for the atom coordinates. - print('loop_', file=file) - print('_atom_site.group_PDB', file=file) - print('_atom_site.id', file=file) - print('_atom_site.type_symbol', file=file) - print('_atom_site.label_atom_id', file=file) - print('_atom_site.label_alt_id', file=file) - print('_atom_site.label_comp_id', file=file) - print('_atom_site.label_asym_id', file=file) - print('_atom_site.label_entity_id', file=file) - print('_atom_site.label_seq_id', file=file) - print('_atom_site.pdbx_PDB_ins_code', file=file) - print('_atom_site.Cartn_x', file=file) - print('_atom_site.Cartn_y', file=file) - print('_atom_site.Cartn_z', file=file) - print('_atom_site.occupancy', file=file) - print('_atom_site.B_iso_or_equiv', file=file) - print('_atom_site.Cartn_x_esd', file=file) - print('_atom_site.Cartn_y_esd', file=file) - print('_atom_site.Cartn_z_esd', file=file) - print('_atom_site.occupancy_esd', file=file) - print('_atom_site.B_iso_or_equiv_esd', file=file) - print('_atom_site.pdbx_formal_charge', file=file) - print('_atom_site.auth_seq_id', file=file) - print('_atom_site.auth_comp_id', file=file) - print('_atom_site.auth_asym_id', file=file) - print('_atom_site.auth_atom_id', file=file) - print('_atom_site.pdbx_PDB_model_num', file=file) + print("loop_", file=file) + print("_atom_site.group_PDB", file=file) + print("_atom_site.id", file=file) + print("_atom_site.type_symbol", file=file) + print("_atom_site.label_atom_id", file=file) + print("_atom_site.label_alt_id", file=file) + print("_atom_site.label_comp_id", file=file) + print("_atom_site.label_asym_id", file=file) + print("_atom_site.label_entity_id", file=file) + print("_atom_site.label_seq_id", file=file) + print("_atom_site.pdbx_PDB_ins_code", file=file) + print("_atom_site.Cartn_x", file=file) + print("_atom_site.Cartn_y", file=file) + print("_atom_site.Cartn_z", file=file) + print("_atom_site.occupancy", file=file) + print("_atom_site.B_iso_or_equiv", file=file) + print("_atom_site.Cartn_x_esd", file=file) + print("_atom_site.Cartn_y_esd", file=file) + print("_atom_site.Cartn_z_esd", file=file) + print("_atom_site.occupancy_esd", file=file) + print("_atom_site.B_iso_or_equiv_esd", file=file) + print("_atom_site.pdbx_formal_charge", file=file) + print("_atom_site.auth_seq_id", file=file) + print("_atom_site.auth_comp_id", file=file) + print("_atom_site.auth_asym_id", file=file) + print("_atom_site.auth_atom_id", file=file) + print("_atom_site.pdbx_PDB_model_num", file=file) @staticmethod def writeModel(topology, positions, file=sys.stdout, modelIndex=1, keepIds=False): @@ -401,30 +444,30 @@ def writeModel(topology, positions, file=sys.stdout, modelIndex=1, keepIds=False PDBx/mmCIF format. Otherwise, the output file will be invalid. """ if len(list(topology.atoms())) != len(positions): - raise ValueError('The number of positions must match the number of atoms') + raise ValueError("The number of positions must match the number of atoms") if is_quantity(positions): positions = positions.value_in_unit(angstroms) if any(math.isnan(norm(pos)) for pos in positions): - raise ValueError('Particle position is NaN') + raise ValueError("Particle position is NaN") if any(math.isinf(norm(pos)) for pos in positions): - raise ValueError('Particle position is infinite') + raise ValueError("Particle position is infinite") nonHeterogens = PDBFile._standardResidues[:] - nonHeterogens.remove('HOH') + nonHeterogens.remove("HOH") atomIndex = 1 posIndex = 0 - for (chainIndex, chain) in enumerate(topology.chains()): + for chainIndex, chain in enumerate(topology.chains()): if keepIds: chainName = chain.id else: - chainName = chr(ord('A')+chainIndex%26) + chainName = chr(ord("A") + chainIndex % 26) residues = list(chain.residues()) - for (resIndex, res) in enumerate(residues): + for resIndex, res in enumerate(residues): if keepIds: resId = res.id - resIC = (res.insertionCode if res.insertionCode.strip() else '.') + resIC = res.insertionCode if res.insertionCode.strip() else "." else: resId = resIndex + 1 - resIC = '.' + resIC = "." if res.name in nonHeterogens: recordName = "ATOM" else: @@ -434,9 +477,29 @@ def writeModel(topology, positions, file=sys.stdout, modelIndex=1, keepIds=False if atom.element is not None: symbol = atom.element.symbol else: - symbol = '?' + symbol = "?" line = "%s %5d %-3s %-4s . %-4s %s ? %5s %s %10.4f %10.4f %10.4f 0.0 0.0 ? ? ? ? ? . %5s %4s %s %4s %5d" - print(line % (recordName, atomIndex, symbol, atom.name, res.name, chainName, resId, resIC, coords[0], coords[1], coords[2], - resId, res.name, chainName, atom.name, modelIndex), file=file) + print( + line + % ( + recordName, + atomIndex, + symbol, + atom.name, + res.name, + chainName, + resId, + resIC, + coords[0], + coords[1], + coords[2], + resId, + res.name, + chainName, + atom.name, + modelIndex, + ), + file=file, + ) posIndex += 1 atomIndex += 1 diff --git a/gufe/vendor/pdb_file/singelton.py b/gufe/vendor/pdb_file/singelton.py index 0fb0b54c..fb7b7f4b 100644 --- a/gufe/vendor/pdb_file/singelton.py +++ b/gufe/vendor/pdb_file/singelton.py @@ -3,13 +3,15 @@ maintains the correctness of instance is instance even following pickling/unpickling """ -class Singleton(object): + + +class Singleton: _inst = None + def __new__(cls): if cls._inst is None: - cls._inst = super(Singleton, cls).__new__(cls) + cls._inst = super().__new__(cls) return cls._inst def __reduce__(self): return repr(self) - diff --git a/gufe/vendor/pdb_file/topology.py b/gufe/vendor/pdb_file/topology.py index d83fa6da..5d5e0d88 100644 --- a/gufe/vendor/pdb_file/topology.py +++ b/gufe/vendor/pdb_file/topology.py @@ -28,50 +28,64 @@ OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from __future__ import absolute_import + __author__ = "Peter Eastman" __version__ = "1.0" import os - -import numpy as np -from copy import deepcopy +import xml.etree.ElementTree as etree from collections import namedtuple +from copy import deepcopy -import xml.etree.ElementTree as etree +import numpy as np +from openmm.unit import is_quantity, nanometers, sqrt from .singelton import Singleton -from openmm.unit import nanometers, sqrt, is_quantity - # Enumerated values for bond type + class Single(Singleton): def __repr__(self): - return 'Single' + return "Single" + + Single = Single() + class Double(Singleton): def __repr__(self): - return 'Double' + return "Double" + + Double = Double() + class Triple(Singleton): def __repr__(self): - return 'Triple' + return "Triple" + + Triple = Triple() + class Aromatic(Singleton): def __repr__(self): - return 'Aromatic' + return "Aromatic" + + Aromatic = Aromatic() + class Amide(Singleton): def __repr__(self): - return 'Amide' + return "Amide" + + Amide = Amide() -class Topology(object): + +class Topology: """Topology stores the topological information about a system. The structure of a Topology object is similar to that of a PDB file. It consists of a set of Chains @@ -98,27 +112,28 @@ def __repr__(self): nres = self._numResidues natom = self._numAtoms nbond = len(self._bonds) - return '<%s; %d chains, %d residues, %d atoms, %d bonds>' % ( - type(self).__name__, nchains, nres, natom, nbond) + return "<%s; %d chains, %d residues, %d atoms, %d bonds>" % ( + type(self).__name__, + nchains, + nres, + natom, + nbond, + ) def getNumAtoms(self): - """Return the number of atoms in the Topology. - """ + """Return the number of atoms in the Topology.""" return self._numAtoms def getNumResidues(self): - """Return the number of residues in the Topology. - """ + """Return the number of residues in the Topology.""" return self._numResidues def getNumChains(self): - """Return the number of chains in the Topology. - """ + """Return the number of chains in the Topology.""" return len(self._chains) def getNumBonds(self): - """Return the number of bonds in the Topology. - """ + """Return the number of bonds in the Topology.""" return len(self._bonds) def addChain(self, id=None): @@ -136,12 +151,12 @@ def addChain(self, id=None): the newly created Chain """ if id is None: - id = str(len(self._chains)+1) + id = str(len(self._chains) + 1) chain = Chain(len(self._chains), self, id) self._chains.append(chain) return chain - def addResidue(self, name, chain, id=None, insertionCode=''): + def addResidue(self, name, chain, id=None, insertionCode=""): """Create a new Residue and add it to the Topology. Parameters @@ -161,10 +176,10 @@ def addResidue(self, name, chain, id=None, insertionCode=''): Residue the newly created Residue """ - if len(chain._residues) > 0 and self._numResidues != chain._residues[-1].index+1: - raise ValueError('All residues within a chain must be contiguous') + if len(chain._residues) > 0 and self._numResidues != chain._residues[-1].index + 1: + raise ValueError("All residues within a chain must be contiguous") if id is None: - id = str(self._numResidues+1) + id = str(self._numResidues + 1) residue = Residue(name, self._numResidues, chain, id, insertionCode) self._numResidues += 1 chain._residues.append(residue) @@ -190,10 +205,10 @@ def addAtom(self, name, element, residue, id=None): Atom the newly created Atom """ - if len(residue._atoms) > 0 and self._numAtoms != residue._atoms[-1].index+1: - raise ValueError('All atoms within a residue must be contiguous') + if len(residue._atoms) > 0 and self._numAtoms != residue._atoms[-1].index + 1: + raise ValueError("All atoms within a residue must be contiguous") if id is None: - id = str(self._numAtoms+1) + id = str(self._numAtoms + 1) atom = Atom(name, element, self._numAtoms, residue, id) self._numAtoms += 1 residue._atoms.append(atom) @@ -223,15 +238,13 @@ def chains(self): def residues(self): """Iterate over all Residues in the Topology.""" for chain in self._chains: - for residue in chain._residues: - yield residue + yield from chain._residues def atoms(self): """Iterate over all Atoms in the Topology.""" for chain in self._chains: for residue in chain._residues: - for atom in residue._atoms: - yield atom + yield from residue._atoms def bonds(self): """Iterate over all bonds (each represented as a tuple of two Atoms) in the Topology.""" @@ -240,20 +253,28 @@ def bonds(self): def getPeriodicBoxVectors(self): """Get the vectors defining the periodic box. - The return value may be None if this Topology does not represent a periodic structure.""" + The return value may be None if this Topology does not represent a periodic structure. + """ return self._periodicBoxVectors def setPeriodicBoxVectors(self, vectors): """Set the vectors defining the periodic box.""" if vectors is not None: if not is_quantity(vectors[0][0]): - vectors = vectors*nanometers - if vectors[0][1] != 0*nanometers or vectors[0][2] != 0*nanometers: - raise ValueError("First periodic box vector must be parallel to x."); - if vectors[1][2] != 0*nanometers: - raise ValueError("Second periodic box vector must be in the x-y plane."); - if vectors[0][0] <= 0*nanometers or vectors[1][1] <= 0*nanometers or vectors[2][2] <= 0*nanometers or vectors[0][0] < 2*abs(vectors[1][0]) or vectors[0][0] < 2*abs(vectors[2][0]) or vectors[1][1] < 2*abs(vectors[2][1]): - raise ValueError("Periodic box vectors must be in reduced form."); + vectors = vectors * nanometers + if vectors[0][1] != 0 * nanometers or vectors[0][2] != 0 * nanometers: + raise ValueError("First periodic box vector must be parallel to x.") + if vectors[1][2] != 0 * nanometers: + raise ValueError("Second periodic box vector must be in the x-y plane.") + if ( + vectors[0][0] <= 0 * nanometers + or vectors[1][1] <= 0 * nanometers + or vectors[2][2] <= 0 * nanometers + or vectors[0][0] < 2 * abs(vectors[1][0]) + or vectors[0][0] < 2 * abs(vectors[2][0]) + or vectors[1][1] < 2 * abs(vectors[2][1]) + ): + raise ValueError("Periodic box vectors must be in reduced form.") self._periodicBoxVectors = deepcopy(vectors) def getUnitCellDimensions(self): @@ -266,19 +287,24 @@ def getUnitCellDimensions(self): xsize = self._periodicBoxVectors[0][0].value_in_unit(nanometers) ysize = self._periodicBoxVectors[1][1].value_in_unit(nanometers) zsize = self._periodicBoxVectors[2][2].value_in_unit(nanometers) - return np.array([xsize, ysize, zsize])*nanometers + return np.array([xsize, ysize, zsize]) * nanometers def setUnitCellDimensions(self, dimensions): """Set the dimensions of the crystallographic unit cell. This method is an alternative to setPeriodicBoxVectors() for the case of a rectangular box. It sets - the box vectors to be orthogonal to each other and to have the specified lengths.""" + the box vectors to be orthogonal to each other and to have the specified lengths. + """ if dimensions is None: self._periodicBoxVectors = None else: if is_quantity(dimensions): dimensions = dimensions.value_in_unit(nanometers) - self._periodicBoxVectors = (np.array([dimensions[0], 0, 0]), np.array([0, dimensions[1], 0]), np.array([0, 0, dimensions[2]]))*nanometers + self._periodicBoxVectors = ( + np.array([dimensions[0], 0, 0]), + np.array([0, dimensions[1], 0]), + np.array([0, 0, dimensions[2]]), + ) * nanometers @staticmethod def loadBondDefinitions(file): @@ -291,12 +317,18 @@ def loadBondDefinitions(file): will be used for any PDB file loaded after this is called. """ tree = etree.parse(file) - for residue in tree.getroot().findall('Residue'): + for residue in tree.getroot().findall("Residue"): bonds = [] - Topology._standardBonds[residue.attrib['name']] = bonds - for bond in residue.findall('Bond'): - bonds.append((bond.attrib['from'], bond.attrib['to'], bond.attrib['type'], int( - bond.attrib['order']))) + Topology._standardBonds[residue.attrib["name"]] = bonds + for bond in residue.findall("Bond"): + bonds.append( + ( + bond.attrib["from"], + bond.attrib["to"], + bond.attrib["type"], + int(bond.attrib["order"]), + ) + ) def createStandardBonds(self): """Create bonds based on the atom and residue names for all standard residue types. @@ -307,7 +339,7 @@ def createStandardBonds(self): if not Topology._hasLoadedStandardBonds: # Load the standard bond definitions. - Topology.loadBondDefinitions(os.path.join(os.path.dirname(__file__), 'data', 'residues.xml')) + Topology.loadBondDefinitions(os.path.join(os.path.dirname(__file__), "data", "residues.xml")) Topology._hasLoadedStandardBonds = True for chain in self._chains: # First build a map of atom names to atoms. @@ -325,52 +357,57 @@ def createStandardBonds(self): name = chain._residues[i].name if name in Topology._standardBonds: for bond in Topology._standardBonds[name]: - if bond[0].startswith('-') and i > 0: - fromResidue = i-1 + if bond[0].startswith("-") and i > 0: + fromResidue = i - 1 fromAtom = bond[0][1:] - elif bond[0].startswith('+') and i 0: - toResidue = i-1 + if bond[1].startswith("-") and i > 0: + toResidue = i - 1 toAtom = bond[1][1:] - elif bond[1].startswith('+') and i ND1=CE1-NE2-HE2 - avoid "charged" resonance structure - bond_atoms = (fromAtom, toAtom) - if(name == "HIS" and "CE1" in bond_atoms and any([N in bond_atoms for N in ["ND1", "NE2"]])): + bond_atoms = (fromAtom, toAtom) + if name == "HIS" and "CE1" in bond_atoms and any([N in bond_atoms for N in ["ND1", "NE2"]]): atoms = atomMaps[i] ND1_protonated = "HD1" in atoms NE2_protonated = "HE2" in atoms - - if(ND1_protonated and not NE2_protonated): # HD1-ND1-CE1=ND2 - if("ND1" in bond_atoms): - bond_order = 1 + + if ND1_protonated and not NE2_protonated: # HD1-ND1-CE1=ND2 + if "ND1" in bond_atoms: + bond_order = 1 else: - bond_order = 2 - elif(not ND1_protonated and NE2_protonated): # ND1=CE1-NE2-HE2 - if("ND1" in bond_atoms): - bond_order = 2 + bond_order = 2 + elif not ND1_protonated and NE2_protonated: # ND1=CE1-NE2-HE2 + if "ND1" in bond_atoms: + bond_order = 2 else: - bond_order = 1 - else: # does not matter if doubly or none protonated. + bond_order = 1 + else: # does not matter if doubly or none protonated. pass - self.addBond(atomMaps[fromResidue][fromAtom], atomMaps[toResidue][toAtom], type=bond_type, order=bond_order) + self.addBond( + atomMaps[fromResidue][fromAtom], + atomMaps[toResidue][toAtom], + type=bond_type, + order=bond_order, + ) def createDisulfideBonds(self, positions): """Identify disulfide bonds based on proximity and add them to the @@ -381,30 +418,31 @@ def createDisulfideBonds(self, positions): positions : list The list of atomic positions based on which to identify bonded atoms """ + def isCyx(res): names = [atom.name for atom in res._atoms] - return 'SG' in names and 'HG' not in names + return "SG" in names and "HG" not in names + # This function is used to prevent multiple di-sulfide bonds from being # assigned to a given atom. def isDisulfideBonded(atom): - for b in self._bonds: - if (atom in b and b[0].name == 'SG' and - b[1].name == 'SG'): - return True + for b in self._bonds: + if atom in b and b[0].name == "SG" and b[1].name == "SG": + return True - return False + return False - cyx = [res for res in self.residues() if res.name == 'CYS' and isCyx(res)] + cyx = [res for res in self.residues() if res.name == "CYS" and isCyx(res)] atomNames = [[atom.name for atom in res._atoms] for res in cyx] for i in range(len(cyx)): - sg1 = cyx[i]._atoms[atomNames[i].index('SG')] + sg1 = cyx[i]._atoms[atomNames[i].index("SG")] pos1 = positions[sg1.index] - candidate_distance, candidate_atom = 0.3*nanometers, None + candidate_distance, candidate_atom = 0.3 * nanometers, None for j in range(i): - sg2 = cyx[j]._atoms[atomNames[j].index('SG')] + sg2 = cyx[j]._atoms[atomNames[j].index("SG")] pos2 = positions[sg2.index] - delta = [x-y for (x,y) in zip(pos1, pos2)] - distance = sqrt(delta[0]*delta[0] + delta[1]*delta[1] + delta[2]*delta[2]) + delta = [x - y for (x, y) in zip(pos1, pos2)] + distance = sqrt(delta[0] * delta[0] + delta[1] * delta[1] + delta[2] * delta[2]) if distance < candidate_distance and not isDisulfideBonded(sg2): candidate_distance = distance candidate_atom = sg2 @@ -412,8 +450,10 @@ def isDisulfideBonded(atom): if candidate_atom: self.addBond(sg1, candidate_atom, type="Single", order=1) -class Chain(object): + +class Chain: """A Chain object represents a chain within a Topology.""" + def __init__(self, index, topology, id): """Construct a new Chain. You should call addChain() on the Topology instead of calling this directly.""" ## The index of the Chain within its Topology @@ -431,8 +471,7 @@ def residues(self): def atoms(self): """Iterate over all Atoms in the Chain.""" for residue in self._residues: - for atom in residue._atoms: - yield atom + yield from residue._atoms def __len__(self): return len(self._residues) @@ -440,8 +479,10 @@ def __len__(self): def __repr__(self): return "" % self.index -class Residue(object): + +class Residue: """A Residue object represents a residue within a Topology.""" + def __init__(self, name, index, chain, id, insertionCode): """Construct a new Residue. You should call addResidue() on the Topology instead of calling this directly.""" ## The name of the Residue @@ -462,23 +503,28 @@ def atoms(self): def bonds(self): """Iterate over all Bonds involving any atom in this residue.""" - return ( bond for bond in self.chain.topology.bonds() if ((bond[0] in self._atoms) or (bond[1] in self._atoms)) ) + return (bond for bond in self.chain.topology.bonds() if ((bond[0] in self._atoms) or (bond[1] in self._atoms))) def internal_bonds(self): """Iterate over all internal Bonds.""" - return ( bond for bond in self.chain.topology.bonds() if ((bond[0] in self._atoms) and (bond[1] in self._atoms)) ) + return (bond for bond in self.chain.topology.bonds() if ((bond[0] in self._atoms) and (bond[1] in self._atoms))) def external_bonds(self): """Iterate over all Bonds to external atoms.""" - return ( bond for bond in self.chain.topology.bonds() if ((bond[0] in self._atoms) != (bond[1] in self._atoms)) ) + return (bond for bond in self.chain.topology.bonds() if ((bond[0] in self._atoms) != (bond[1] in self._atoms))) def __len__(self): return len(self._atoms) def __repr__(self): - return "" % (self.index, self.name, self.chain.index) + return "" % ( + self.index, + self.name, + self.chain.index, + ) + -class Atom(object): +class Atom: """An Atom object represents an atom within a Topology.""" def __init__(self, name, element, index, residue, id): @@ -495,17 +541,25 @@ def __init__(self, name, element, index, residue, id): self.id = id def __repr__(self): - return "" % (self.index, self.name, self.residue.chain.index, self.residue.index, self.residue.name) + return "" % ( + self.index, + self.name, + self.residue.chain.index, + self.residue.index, + self.residue.name, + ) -class Bond(namedtuple('Bond', ['atom1', 'atom2'])): + +class Bond(namedtuple("Bond", ["atom1", "atom2"])): """A Bond object represents a bond between two Atoms within a Topology. This class extends tuple, and may be interpreted as a 2 element tuple of Atom objects. - It also has fields that can optionally be used to describe the bond order and type of bond.""" + It also has fields that can optionally be used to describe the bond order and type of bond. + """ def __new__(cls, atom1, atom2, type=None, order=None): """Create a new Bond. You should call addBond() on the Topology instead of calling this directly.""" - bond = super(Bond, cls).__new__(cls, atom1, atom2) + bond = super().__new__(cls, atom1, atom2) bond.type = type bond.order = order return bond @@ -526,9 +580,9 @@ def __deepcopy__(self, memo): return Bond(self[0], self[1], self.type, self.order) def __repr__(self): - s = "Bond(%s, %s" % (self[0], self[1]) + s = f"Bond({self[0]}, {self[1]}" if self.type is not None: - s = "%s, type=%s" % (s, self.type) + s = f"{s}, type={self.type}" if self.order is not None: s = "%s, order=%d" % (s, self.order) s += ")" diff --git a/gufe/vendor/pdb_file/unitcell.py b/gufe/vendor/pdb_file/unitcell.py index e4068707..bc65cec7 100644 --- a/gufe/vendor/pdb_file/unitcell.py +++ b/gufe/vendor/pdb_file/unitcell.py @@ -28,36 +28,43 @@ OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from __future__ import absolute_import + __author__ = "Peter Eastman" __version__ = "1.0" -import numpy as np -from openmm.unit import nanometers, is_quantity, norm, dot, radians import math +import numpy as np +from openmm.unit import dot, is_quantity, nanometers, norm, radians + def computePeriodicBoxVectors(a_length, b_length, c_length, alpha, beta, gamma): """Convert lengths and angles to periodic box vectors. - + Lengths should be given in nanometers and angles in radians (or as Quantity instances) """ - if is_quantity(a_length): a_length = a_length.value_in_unit(nanometers) - if is_quantity(b_length): b_length = b_length.value_in_unit(nanometers) - if is_quantity(c_length): c_length = c_length.value_in_unit(nanometers) - if is_quantity(alpha): alpha = alpha.value_in_unit(radians) - if is_quantity(beta): beta = beta.value_in_unit(radians) - if is_quantity(gamma): gamma = gamma.value_in_unit(radians) + if is_quantity(a_length): + a_length = a_length.value_in_unit(nanometers) + if is_quantity(b_length): + b_length = b_length.value_in_unit(nanometers) + if is_quantity(c_length): + c_length = c_length.value_in_unit(nanometers) + if is_quantity(alpha): + alpha = alpha.value_in_unit(radians) + if is_quantity(beta): + beta = beta.value_in_unit(radians) + if is_quantity(gamma): + gamma = gamma.value_in_unit(radians) # Compute the vectors. a = [a_length, 0, 0] - b = [b_length*math.cos(gamma), b_length*math.sin(gamma), 0] - cx = c_length*math.cos(beta) - cy = c_length*(math.cos(alpha)-math.cos(beta)*math.cos(gamma))/math.sin(gamma) - cz = math.sqrt(c_length*c_length-cx*cx-cy*cy) + b = [b_length * math.cos(gamma), b_length * math.sin(gamma), 0] + cx = c_length * math.cos(beta) + cy = c_length * (math.cos(alpha) - math.cos(beta) * math.cos(gamma)) / math.sin(gamma) + cz = math.sqrt(c_length * c_length - cx * cx - cy * cy) c = [cx, cy, cz] # If any elements are very close to 0, set them to exactly 0. @@ -75,13 +82,14 @@ def computePeriodicBoxVectors(a_length, b_length, c_length, alpha, beta, gamma): # Make sure they're in the reduced form required by OpenMM. - c = c - b*round(c[1]/b[1]) - c = c - a*round(c[0]/a[0]) - b = b - a*round(b[0]/a[0]) - return (a, b, c)*nanometers + c = c - b * round(c[1] / b[1]) + c = c - a * round(c[0] / a[0]) + b = b - a * round(b[0] / a[0]) + return (a, b, c) * nanometers + def reducePeriodicBoxVectors(periodicBoxVectors): - """ Reduces the representation of the PBC. periodicBoxVectors is expected to + """Reduces the representation of the PBC. periodicBoxVectors is expected to be an unpackable iterable of length-3 iterables """ if is_quantity(periodicBoxVectors): @@ -92,12 +100,13 @@ def reducePeriodicBoxVectors(periodicBoxVectors): b = np.array(b) c = np.array(c) - c = c - b*round(c[1]/b[1]) - c = c - a*round(c[0]/a[0]) - b = b - a*round(b[0]/a[0]) + c = c - b * round(c[1] / b[1]) + c = c - a * round(c[0] / a[0]) + b = b - a * round(b[0] / a[0]) return (a, b, c) * nanometers + def computeLengthsAndAngles(periodicBoxVectors): """Convert periodic box vectors to lengths and angles. @@ -110,7 +119,7 @@ def computeLengthsAndAngles(periodicBoxVectors): a_length = norm(a) b_length = norm(b) c_length = norm(c) - alpha = math.acos(dot(b, c)/(b_length*c_length)) - beta = math.acos(dot(c, a)/(c_length*a_length)) - gamma = math.acos(dot(a, b)/(a_length*b_length)) + alpha = math.acos(dot(b, c) / (b_length * c_length)) + beta = math.acos(dot(c, a) / (c_length * a_length)) + gamma = math.acos(dot(a, b) / (a_length * b_length)) return (a_length, b_length, c_length, alpha, beta, gamma) diff --git a/gufe/visualization/mapping_visualization.py b/gufe/visualization/mapping_visualization.py index 0b6c02be..ac110b8a 100644 --- a/gufe/visualization/mapping_visualization.py +++ b/gufe/visualization/mapping_visualization.py @@ -1,12 +1,11 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/gufe -from typing import Any, Collection, Optional +from collections.abc import Collection from itertools import chain +from typing import Any, Optional from rdkit import Chem -from rdkit.Chem import Draw -from rdkit.Chem import AllChem - +from rdkit.Chem import AllChem, Draw # highlight core element changes differently from unique atoms # RGBA color value needs to be between 0 and 1, so divide by 255 @@ -14,9 +13,7 @@ BLUE = (0.0, 90 / 255, 181 / 255, 1.0) -def _match_elements( - mol1: Chem.Mol, idx1: int, mol2: Chem.Mol, idx2: int -) -> bool: +def _match_elements(mol1: Chem.Mol, idx1: int, mol2: Chem.Mol, idx2: int) -> bool: """ Convenience method to check if elements between two molecules (molA and molB) are the same. @@ -42,9 +39,7 @@ def _match_elements( return elem_mol1 == elem_mol2 -def _get_unique_bonds_and_atoms( - mapping: dict[int, int], mol1: Chem.Mol, mol2: Chem.Mol -) -> dict: +def _get_unique_bonds_and_atoms(mapping: dict[int, int], mol1: Chem.Mol, mol2: Chem.Mol) -> dict: """ Given an input mapping, returns new atoms, element changes, and involved bonds. @@ -143,13 +138,14 @@ def _draw_molecules( if d2d is None: # select default layout based on number of molecules - grid_x, grid_y = {1: (1, 1), 2: (2, 1), }[len(mols)] - d2d = Draw.rdMolDraw2D.MolDraw2DCairo( - grid_x * 300, grid_y * 300, 300, 300) + grid_x, grid_y = { + 1: (1, 1), + 2: (2, 1), + }[len(mols)] + d2d = Draw.rdMolDraw2D.MolDraw2DCairo(grid_x * 300, grid_y * 300, 300, 300) # get molecule name labels - labels = [m.GetProp("ofe-name") if(m.HasProp("ofe-name")) - else "" for m in mols] + labels = [m.GetProp("ofe-name") if (m.HasProp("ofe-name")) else "" for m in mols] # squash to 2D copies = [Chem.Mol(mol) for mol in mols] @@ -158,10 +154,7 @@ def _draw_molecules( # mol alignments if atom_mapping present for (i, j), atomMap in atom_mapping.items(): - AllChem.AlignMol( - copies[j], copies[i], - atomMap=[(k, v) for v, k in atomMap.items()] - ) + AllChem.AlignMol(copies[j], copies[i], atomMap=[(k, v) for v, k in atomMap.items()]) # standard settings for our visualization d2d.drawOptions().useBWAtomPalette() @@ -180,9 +173,7 @@ def _draw_molecules( return d2d.GetDrawingText() -def draw_mapping( - mol1_to_mol2: dict[int, int], mol1: Chem.Mol, mol2: Chem.Mol, d2d=None -): +def draw_mapping(mol1_to_mol2: dict[int, int], mol1: Chem.Mol, mol2: Chem.Mol, d2d=None): """ Method to visualise the atom map correspondence between two rdkit molecules given an input mapping. @@ -229,7 +220,6 @@ def draw_mapping( atom_colors = [at1_colors, at2_colors] bond_colors = [bd1_colors, bd2_colors] - return _draw_molecules( d2d, [mol1, mol2], @@ -270,14 +260,15 @@ def draw_one_molecule_mapping(mol1_to_mol2, mol1, mol2, d2d=None): atom_colors = [{at: BLUE for at in uniques["elements"]}] bond_colors = [{bd: BLUE for bd in uniques["bond_changes"]}] - return _draw_molecules(d2d, - [mol1], - atoms_list, - bonds_list, - atom_colors, - bond_colors, - RED, - ) + return _draw_molecules( + d2d, + [mol1], + atoms_list, + bonds_list, + atom_colors, + bond_colors, + RED, + ) def draw_unhighlighted_molecule(mol, d2d=None): diff --git a/pyproject.toml b/pyproject.toml index d68b3488..f7833013 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,6 +77,26 @@ dirty = "{base_version}+{distance}.{vcs}{rev}.dirty" distance-dirty = "{base_version}+{distance}.{vcs}{rev}.dirty" [tool.versioningit.vcs] -method = "git" +method = "git" match = ["*"] default-tag = "0.0.0" + +[tool.black] +line-length = 120 + +[tool.isort] +multi_line_output = 3 +include_trailing_comma = true +force_grid_wrap = 0 +use_parentheses = true +line_length = 120 +profile = "black" +known_first_party = ["gufe"] + +[tool.interrogate] +fail-under = 0 +ignore-regex = ["^get$", "^mock_.*", ".*BaseClass.*"] +# possible values for verbose: 0 (minimal output), 1 (-v), 2 (-vv) +verbose = 2 +color = true +exclude = ["build", "docs"]