diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
new file mode 100644
index 00000000..e66d5736
--- /dev/null
+++ b/.git-blame-ignore-revs
@@ -0,0 +1,2 @@
+# https://github.com/OpenFreeEnergy/gufe/pull/421 -- big auto format MMH
+d27100a5b7b303df155e2b6d7874883d105e24bf
diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
new file mode 100644
index 00000000..d34f4787
--- /dev/null
+++ b/.github/pull_request_template.md
@@ -0,0 +1,22 @@
+
+
+
+
+
+Tips
+* Comment "pre-commit.ci autofix" to have pre-commit.ci automatically format your PR.
+ Since this will create a commit, it is best to make this comment when you are finished with your work.
+
+
+Checklist
+* [ ] Added a ``news`` entry
+
+## Developers certificate of origin
+- [ ] I certify that this contribution is covered by the MIT License [here](https://github.com/OpenFreeEnergy/openfe/blob/main/LICENSE) and the **Developer Certificate of Origin** at <https://developercertificate.org/>.
diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 1914b712..6df2ffd6 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -27,18 +27,15 @@ jobs:
strategy:
fail-fast: false
matrix:
- os: ['ubuntu-latest']
+ os: ['ubuntu-latest', macos-latest]
pydantic-version: [">1"]
python-version:
- "3.10"
- - "3.11"
+ - "3.11"
- "3.12"
include:
- # Note: pinned to macos-12
+ # Note: we still need to add support for macos-13
# see https://github.com/OpenFreeEnergy/openfe/issues/842
- - os: "macos-12"
- python-version: "3.11"
- pydantic-version: ">1"
- os: "ubuntu-latest"
python-version: "3.11"
pydantic-version: "<2"
diff --git a/.github/workflows/clean_cache.yaml b/.github/workflows/clean_cache.yaml
index c3db0a6f..e9a6c50d 100644
--- a/.github/workflows/clean_cache.yaml
+++ b/.github/workflows/clean_cache.yaml
@@ -11,18 +11,18 @@ jobs:
steps:
- name: Check out code
uses: actions/checkout@v3
-
+
- name: Cleanup
run: |
gh extension install actions/gh-actions-cache
-
+
REPO=${{ github.repository }}
BRANCH="refs/pull/${{ github.event.pull_request.number }}/merge"
echo "Fetching list of cache key"
cacheKeysForPR=$(gh actions-cache list -R $REPO -B $BRANCH | cut -f 1 )
- ## Setting this to not fail the workflow while deleting cache keys.
+ ## Setting this to not fail the workflow while deleting cache keys.
set +e
echo "Deleting caches..."
for cacheKey in $cacheKeysForPR
diff --git a/.github/workflows/conda_cron.yaml b/.github/workflows/conda_cron.yaml
index 33d3a954..c4f7ec49 100644
--- a/.github/workflows/conda_cron.yaml
+++ b/.github/workflows/conda_cron.yaml
@@ -24,7 +24,7 @@ jobs:
strategy:
fail-fast: false
matrix:
- os: ['ubuntu-latest', 'macos-14']
+ os: ['ubuntu-latest', 'macos-latest']
python-version:
- "3.10"
- "3.11"
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 00000000..7b469f8c
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,41 @@
+ci:
+ autoupdate_schedule: "quarterly"
+ # comment / label "pre-commit.ci autofix" to a pull request to manually trigger auto-fixing
+ autofix_prs: false
+repos:
+- repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.5.0
+ hooks:
+ - id: check-added-large-files
+ - id: check-case-conflict
+ - id: check-executables-have-shebangs
+ - id: check-symlinks
+ - id: check-toml
+ - id: check-yaml
+ - id: debug-statements
+ - id: destroyed-symlinks
+ - id: end-of-file-fixer
+ exclude: '\.(graphml)$'
+ - id: trailing-whitespace
+ exclude: '\.(pdb|gro|top|sdf|xml|cif|graphml)$'
+- repo: https://github.com/psf/black-pre-commit-mirror
+ rev: 24.2.0
+ hooks:
+ - id: black
+ - id: black-jupyter
+- repo: https://github.com/PyCQA/isort
+ rev: 5.13.2
+ hooks:
+ - id: isort
+ args: ["--profile", "black", "--filter-files"]
+- repo: https://github.com/econchick/interrogate
+ rev: 1.5.0
+ hooks:
+ - id: interrogate
+ args: [--fail-under=28]
+ pass_filenames: false
+- repo: https://github.com/asottile/pyupgrade
+ rev: v3.15.0
+ hooks:
+ - id: pyupgrade
+ args: ["--py39-plus"]
diff --git a/README.md b/README.md
index 929185a4..5fc2c7e2 100644
--- a/README.md
+++ b/README.md
@@ -2,6 +2,7 @@
[![build](https://github.com/OpenFreeEnergy/gufe/actions/workflows/ci.yaml/badge.svg)](https://github.com/OpenFreeEnergy/gufe/actions/workflows/ci.yaml)
[![coverage](https://codecov.io/gh/OpenFreeEnergy/gufe/branch/main/graph/badge.svg)](https://codecov.io/gh/OpenFreeEnergy/gufe)
[![Documentation Status](https://readthedocs.org/projects/gufe/badge/?version=latest)](https://gufe.readthedocs.io/en/latest/?badge=latest)
+[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/OpenFreeEnergy/gufe/main.svg)](https://results.pre-commit.ci/latest/github/OpenFreeEnergy/gufe/main)
# gufe - Grand Unified Free Energy
diff --git a/devtools/raise-or-close-issue.py b/devtools/raise-or-close-issue.py
index ff9d7742..97406337 100644
--- a/devtools/raise-or-close-issue.py
+++ b/devtools/raise-or-close-issue.py
@@ -5,24 +5,24 @@
# - TITLE: A string title which the issue will have.
import os
-from github import Github
+from github import Github
if __name__ == "__main__":
- git = Github(os.environ['GITHUB_TOKEN'])
- status = os.environ['CI_OUTCOME']
- repo = git.get_repo('OpenFreeEnergy/gufe')
- title = os.environ['TITLE']
-
+ git = Github(os.environ["GITHUB_TOKEN"])
+ status = os.environ["CI_OUTCOME"]
+ repo = git.get_repo("OpenFreeEnergy/gufe")
+ title = os.environ["TITLE"]
+
target_issue = None
for issue in repo.get_issues():
if issue.title == title:
target_issue = issue
-
+
# Close any issues with given title if CI returned green
- if status == 'success':
+ if status == "success":
if target_issue is not None:
- target_issue.edit(state='closed')
+ target_issue.edit(state="closed")
else:
# Otherwise raise an issue
if target_issue is None:
diff --git a/docs/CHANGELOG.rst b/docs/CHANGELOG.rst
index a50a0eda..2659e23d 100644
--- a/docs/CHANGELOG.rst
+++ b/docs/CHANGELOG.rst
@@ -28,5 +28,3 @@ v1.1.0
* Fixed an issue where ProtocolDAG DAG order & keys were unstable /
non-deterministic between processes under some circumstances (PR #315).
* Fixed a bug where edge annotations were lost when converting a ``LigandNetwork`` to graphml, all JSON codec types are now supported.
-
-
diff --git a/docs/_static/.gitkeep b/docs/_static/.gitkeep
index 8b137891..e69de29b 100644
--- a/docs/_static/.gitkeep
+++ b/docs/_static/.gitkeep
@@ -1 +0,0 @@
-
diff --git a/docs/conf.py b/docs/conf.py
index 905e4840..62d8bbdb 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -12,14 +12,15 @@
#
import os
import sys
-sys.path.insert(0, os.path.abspath('../'))
+
+sys.path.insert(0, os.path.abspath("../"))
# -- Project information -----------------------------------------------------
-project = 'gufe'
-copyright = '2022, The OpenFE Development Team'
-author = 'The OpenFE Development Team'
+project = "gufe"
+copyright = "2022, The OpenFE Development Team"
+author = "The OpenFE Development Team"
# -- General configuration ---------------------------------------------------
@@ -28,14 +29,14 @@
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
- 'sphinx.ext.autodoc',
- 'sphinx.ext.autosummary',
- 'sphinx.ext.napoleon',
- 'sphinxcontrib.autodoc_pydantic',
- 'sphinx.ext.intersphinx',
+ "sphinx.ext.autodoc",
+ "sphinx.ext.autosummary",
+ "sphinx.ext.napoleon",
+ "sphinxcontrib.autodoc_pydantic",
+ "sphinx.ext.intersphinx",
]
-autoclass_content = 'both'
+autoclass_content = "both"
autodoc_default_options = {
"members": True,
@@ -47,20 +48,21 @@
autosummary_generate = True
intersphinx_mapping = {
- 'rdkit': ('https://www.rdkit.org/docs/', None),
+ "rdkit": ("https://www.rdkit.org/docs/", None),
}
# Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
+templates_path = ["_templates"]
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
-exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
-autodoc_mock_imports = ["openff.models",
- "rdkit",
- "networkx",
+autodoc_mock_imports = [
+ "openff.models",
+ "rdkit",
+ "networkx",
]
@@ -77,7 +79,7 @@
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
+html_static_path = ["_static"]
# replace macros
rst_prolog = """
diff --git a/docs/guide/overview.rst b/docs/guide/overview.rst
index 30b9fd5f..86086b63 100644
--- a/docs/guide/overview.rst
+++ b/docs/guide/overview.rst
@@ -25,7 +25,7 @@ Ligand network setup
GUFE defines a basic API for the common case of performing alchemical
transformations between small molecules, either for relative binding free
energies of relative hydration free energies. This handles how mappings
-between different molecules are defined for alchemical transformations,
+between different molecules are defined for alchemical transformations,
by defining both the :class:`.LigandAtomMapping` object that contains the
details of a specific mapping, and the :class:`.AtomMapper` abstract API for
an object that creates the mappings.
diff --git a/docs/guide/serialization.rst b/docs/guide/serialization.rst
index 3e81c9e9..c983233d 100644
--- a/docs/guide/serialization.rst
+++ b/docs/guide/serialization.rst
@@ -245,7 +245,7 @@ classmethod:
.. Using JSON codecs outside of JSON
.. ---------------------------------
-.. In a custom recursive storage scheme,
+.. In a custom recursive storage scheme,
.. TODO: DWHS wants to write something here that describes how to use the
codecs in your own non-JSON storage scheme. But this is complicated
diff --git a/gufe/__init__.py b/gufe/__init__.py
index 10e12d57..c943abf9 100644
--- a/gufe/__init__.py
+++ b/gufe/__init__.py
@@ -3,40 +3,21 @@
from importlib.metadata import version
-from . import tokenization
-
-from . import visualization
-
-from .components import (
- Component,
- SmallMoleculeComponent,
- ProteinComponent,
- SolventComponent
-)
-
+from . import tokenization, visualization
from .chemicalsystem import ChemicalSystem
-
-from .mapping import (
- ComponentMapping, # how individual Components relate
- AtomMapping, AtomMapper, # more specific to atom based components
- LigandAtomMapping,
-)
-
-from .settings import Settings
-
-from .protocols import (
- Context,
- Protocol, # description of a method
- ProtocolUnit, # the individual step within a method
- ProtocolDAG, # many Units forming a workflow
- ProtocolUnitResult, # the result of a single Unit
- ProtocolDAGResult, # the collected result of a DAG
- ProtocolResult, # potentially many DAGs together, giving an estimate
-)
-
-from .transformations import Transformation, NonTransformation
-
-from .network import AlchemicalNetwork
+from .components import Component, ProteinComponent, SmallMoleculeComponent, SolventComponent
from .ligandnetwork import LigandNetwork
+from .mapping import AtomMapper # more specific to atom based components
+from .mapping import ComponentMapping # how individual Components relate
+from .mapping import AtomMapping, LigandAtomMapping
+from .network import AlchemicalNetwork
+from .protocols import Protocol # description of a method
+from .protocols import ProtocolDAG # many Units forming a workflow
+from .protocols import ProtocolDAGResult # the collected result of a DAG
+from .protocols import ProtocolUnit # the individual step within a method
+from .protocols import ProtocolUnitResult # the result of a single Unit
+from .protocols import Context, ProtocolResult # potentially many DAGs together, giving an estimate
+from .settings import Settings
+from .transformations import NonTransformation, Transformation
__version__ = version("gufe")
diff --git a/gufe/chemicalsystem.py b/gufe/chemicalsystem.py
index dedc5619..c4eb4add 100644
--- a/gufe/chemicalsystem.py
+++ b/gufe/chemicalsystem.py
@@ -4,8 +4,8 @@
from collections import abc
from typing import Optional
-from .tokenization import GufeTokenizable
from .components import Component
+from .tokenization import GufeTokenizable
class ChemicalSystem(GufeTokenizable, abc.Mapping):
@@ -14,14 +14,14 @@ def __init__(
components: dict[str, Component],
name: Optional[str] = "",
):
- """A combination of Components that form a system
+ r"""A combination of Components that form a system
Containing a combination of :class:`.SmallMoleculeComponent`,
:class:`.SolventComponent` and :class:`.ProteinComponent`, this object
typically represents all the molecules in a simulation box.
Used as a node for an :class:`.AlchemicalNetwork`.
-
+
Parameters
----------
components
@@ -40,24 +40,18 @@ def __init__(
self._name = name
def __repr__(self):
- return (
- f"{self.__class__.__name__}(name={self.name}, components={self.components})"
- )
+ return f"{self.__class__.__name__}(name={self.name}, components={self.components})"
def _to_dict(self):
return {
- "components": {
- key: value for key, value in sorted(self.components.items())
- },
+ "components": {key: value for key, value in sorted(self.components.items())},
"name": self.name,
}
@classmethod
def _from_dict(cls, d):
return cls(
- components={
- key: value for key, value in d["components"].items()
- },
+ components={key: value for key, value in d["components"].items()},
name=d["name"],
)
@@ -86,7 +80,7 @@ def name(self):
def total_charge(self):
"""Formal charge for the ChemicalSystem."""
# This might evaluate the property twice?
- #return sum(component.total_charge
+ # return sum(component.total_charge
# for component in self._components.values()
# if component.total_charge is not None)
total_charge = 0
diff --git a/gufe/components/__init__.py b/gufe/components/__init__.py
index 6713db21..0f9a69ea 100644
--- a/gufe/components/__init__.py
+++ b/gufe/components/__init__.py
@@ -1,7 +1,6 @@
"""The building blocks for defining systems"""
from .component import Component
-
-from .smallmoleculecomponent import SmallMoleculeComponent
from .proteincomponent import ProteinComponent
+from .smallmoleculecomponent import SmallMoleculeComponent
from .solventcomponent import SolventComponent
diff --git a/gufe/components/explicitmoleculecomponent.py b/gufe/components/explicitmoleculecomponent.py
index a3a513e5..88f1f196 100644
--- a/gufe/components/explicitmoleculecomponent.py
+++ b/gufe/components/explicitmoleculecomponent.py
@@ -1,13 +1,13 @@
import json
-import numpy as np
import warnings
-from rdkit import Chem
from typing import Optional
-from .component import Component
+import numpy as np
+from rdkit import Chem
# typing
from ..custom_typing import RDKitMol
+from .component import Component
def _ensure_ofe_name(mol: RDKitMol, name: str) -> str:
@@ -26,10 +26,7 @@ def _ensure_ofe_name(mol: RDKitMol, name: str) -> str:
pass
if name and rdkit_name and rdkit_name != name:
- warnings.warn(
- f"Component being renamed from {rdkit_name}"
- f"to {name}."
- )
+        warnings.warn(f"Component being renamed from {rdkit_name} " f"to {name}.")
elif name == "":
name = rdkit_name
@@ -56,21 +53,19 @@ def _check_partial_charges(mol: RDKitMol) -> None:
* If partial charges are found.
* If the partial charges are near 0 for all atoms.
"""
- if 'atom.dprop.PartialCharge' not in mol.GetPropNames():
+ if "atom.dprop.PartialCharge" not in mol.GetPropNames():
return
- p_chgs = np.array(
- mol.GetProp('atom.dprop.PartialCharge').split(), dtype=float
- )
+ p_chgs = np.array(mol.GetProp("atom.dprop.PartialCharge").split(), dtype=float)
if len(p_chgs) != mol.GetNumAtoms():
- errmsg = (f"Incorrect number of partial charges: {len(p_chgs)} "
- f" were provided for {mol.GetNumAtoms()} atoms")
+        errmsg = f"Incorrect number of partial charges: {len(p_chgs)} " f"were provided for {mol.GetNumAtoms()} atoms"
raise ValueError(errmsg)
if (sum(p_chgs) - Chem.GetFormalCharge(mol)) > 0.01:
- errmsg = (f"Sum of partial charges {sum(p_chgs)} differs from "
- f"RDKit formal charge {Chem.GetFormalCharge(mol)}")
+ errmsg = (
+ f"Sum of partial charges {sum(p_chgs)} differs from " f"RDKit formal charge {Chem.GetFormalCharge(mol)}"
+ )
raise ValueError(errmsg)
# set the charges on the atoms if not already set
@@ -81,18 +76,20 @@ def _check_partial_charges(mol: RDKitMol) -> None:
else:
atom_charge = atom.GetDoubleProp("PartialCharge")
if not np.isclose(atom_charge, charge):
- errmsg = (f"non-equivalent partial charges between atom and "
- f"molecule properties: {atom_charge} {charge}")
+ errmsg = (
+ f"non-equivalent partial charges between atom and " f"molecule properties: {atom_charge} {charge}"
+ )
raise ValueError(errmsg)
if np.all(np.isclose(p_chgs, 0.0)):
- wmsg = (f"Partial charges provided all equal to "
- "zero. These may be ignored by some Protocols.")
+ wmsg = f"Partial charges provided all equal to " "zero. These may be ignored by some Protocols."
warnings.warn(wmsg)
else:
- wmsg = ("Partial charges have been provided, these will "
- "preferentially be used instead of generating new "
- "partial charges")
+ wmsg = (
+ "Partial charges have been provided, these will "
+ "preferentially be used instead of generating new "
+ "partial charges"
+ )
warnings.warn(wmsg)
@@ -103,6 +100,7 @@ class ExplicitMoleculeComponent(Component):
representations. Specific file formats, such as SDF for small molecules
or PDB for proteins, should be implemented in subclasses.
"""
+
_rdkit: Chem.Mol
_name: str
@@ -115,14 +113,13 @@ def __init__(self, rdkit: RDKitMol, name: str = ""):
n_confs = len(conformers)
if n_confs > 1:
- warnings.warn(
- f"Molecule provided with {n_confs} conformers. "
- f"Only the first will be used."
- )
+ warnings.warn(f"Molecule provided with {n_confs} conformers. " f"Only the first will be used.")
if not any(atom.GetAtomicNum() == 1 for atom in rdkit.GetAtoms()):
- warnings.warn("Molecule doesn't have any hydrogen atoms present. "
- "If this is unexpected, consider loading the molecule with `removeHs=False`")
+ warnings.warn(
+ "Molecule doesn't have any hydrogen atoms present. "
+ "If this is unexpected, consider loading the molecule with `removeHs=False`"
+ )
self._rdkit = rdkit
self._smiles: Optional[str] = None
diff --git a/gufe/components/proteincomponent.py b/gufe/components/proteincomponent.py
index fce2f65c..a23ccf09 100644
--- a/gufe/components/proteincomponent.py
+++ b/gufe/components/proteincomponent.py
@@ -1,28 +1,24 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/gufe
import ast
-import json
import io
-import numpy as np
-from os import PathLike
+import json
import string
-from typing import Union, Optional
from collections import defaultdict
+from os import PathLike
+from typing import Optional, Union
+import numpy as np
from openmm import app
from openmm import unit as omm_unit
-
from rdkit import Chem
-from rdkit.Chem.rdchem import Mol, Atom, Conformer, EditableMol, BondType
+from rdkit.Chem.rdchem import Atom, BondType, Conformer, EditableMol, Mol
from ..custom_typing import RDKitMol
-from .explicitmoleculecomponent import ExplicitMoleculeComponent
+from ..molhashing import deserialize_numpy, serialize_numpy
from ..vendor.pdb_file.pdbfile import PDBFile
from ..vendor.pdb_file.pdbxfile import PDBxFile
-
-
-from ..molhashing import deserialize_numpy, serialize_numpy
-
+from .explicitmoleculecomponent import ExplicitMoleculeComponent
_BONDORDERS_OPENMM_TO_RDKIT = {
1: BondType.SINGLE,
@@ -34,9 +30,7 @@
app.Aromatic: BondType.AROMATIC,
None: BondType.UNSPECIFIED,
}
-_BONDORDERS_RDKIT_TO_OPENMM = {
- v: k for k, v in _BONDORDERS_OPENMM_TO_RDKIT.items()
-}
+_BONDORDERS_RDKIT_TO_OPENMM = {v: k for k, v in _BONDORDERS_OPENMM_TO_RDKIT.items()}
_BONDORDER_TO_ORDER = {
BondType.UNSPECIFIED: 1, # assumption
BondType.SINGLE: 1,
@@ -50,21 +44,29 @@
_BONDORDER_RDKIT_TO_STR = {v: k for k, v in _BONDORDER_STR_TO_RDKIT.items()}
_CHIRALITY_RDKIT_TO_STR = {
- Chem.CHI_TETRAHEDRAL_CW: 'CW',
- Chem.CHI_TETRAHEDRAL_CCW: 'CCW',
- Chem.CHI_UNSPECIFIED: 'U',
-}
-_CHIRALITY_STR_TO_RDKIT = {
- v: k for k, v in _CHIRALITY_RDKIT_TO_STR.items()
+ Chem.CHI_TETRAHEDRAL_CW: "CW",
+ Chem.CHI_TETRAHEDRAL_CCW: "CCW",
+ Chem.CHI_UNSPECIFIED: "U",
}
+_CHIRALITY_STR_TO_RDKIT = {v: k for k, v in _CHIRALITY_RDKIT_TO_STR.items()}
negative_ions = ["F", "CL", "BR", "I"]
positive_ions = [
# +1
- "LI", "NA", "K", "RB", "CS",
+ "LI",
+ "NA",
+ "K",
+ "RB",
+ "CS",
# +2
- "BE", "MG", "CA", "SR", "BA", "RA", "ZN",
+ "BE",
+ "MG",
+ "CA",
+ "SR",
+ "BA",
+ "RA",
+ "ZN",
]
@@ -91,10 +93,13 @@ class ProteinComponent(ExplicitMoleculeComponent):
edit the molecule do this in an appropriate toolkit **before** creating
an instance from this class.
"""
+
def __init__(self, rdkit: RDKitMol, name=""):
if not all(a.GetMonomerInfo() is not None for a in rdkit.GetAtoms()):
- raise TypeError("Not all atoms in input have MonomerInfo defined. "
- "Consider loading via rdkit.Chem.MolFromPDBFile or similar.")
+ raise TypeError(
+ "Not all atoms in input have MonomerInfo defined. "
+ "Consider loading via rdkit.Chem.MolFromPDBFile or similar."
+ )
super().__init__(rdkit=rdkit, name=name)
# FROM
@@ -116,9 +121,7 @@ def from_pdb_file(cls, pdb_file: str, name: str = ""):
the deserialized molecule
"""
openmm_PDBFile = PDBFile(pdb_file)
- return cls._from_openmmPDBFile(
- openmm_PDBFile=openmm_PDBFile, name=name
- )
+ return cls._from_openmmPDBFile(openmm_PDBFile=openmm_PDBFile, name=name)
@classmethod
def from_pdbx_file(cls, pdbx_file: str, name=""):
@@ -138,13 +141,10 @@ def from_pdbx_file(cls, pdbx_file: str, name=""):
the deserialized molecule
"""
openmm_PDBxFile = PDBxFile(pdbx_file)
- return cls._from_openmmPDBFile(
- openmm_PDBFile=openmm_PDBxFile, name=name
- )
+ return cls._from_openmmPDBFile(openmm_PDBFile=openmm_PDBxFile, name=name)
@classmethod
- def _from_openmmPDBFile(cls, openmm_PDBFile: Union[PDBFile, PDBxFile],
- name: str = ""):
+ def _from_openmmPDBFile(cls, openmm_PDBFile: Union[PDBFile, PDBxFile], name: str = ""):
"""Converts to our internal representation (rdkit Mol)
Parameters
@@ -199,9 +199,7 @@ def _from_openmmPDBFile(cls, openmm_PDBFile: Union[PDBFile, PDBxFile],
# Set Positions
rd_mol = editable_rdmol.GetMol()
- positions = np.array(
- openmm_PDBFile.positions.value_in_unit(omm_unit.angstrom), ndmin=3
- )
+ positions = np.array(openmm_PDBFile.positions.value_in_unit(omm_unit.angstrom), ndmin=3)
for frame_id, frame in enumerate(positions):
conf = Conformer(frame_id)
@@ -216,18 +214,15 @@ def _from_openmmPDBFile(cls, openmm_PDBFile: Union[PDBFile, PDBxFile],
atomic_num = a.GetAtomicNum()
atom_name = a.GetMonomerInfo().GetName()
- connectivity = sum(
- _BONDORDER_TO_ORDER[bond.GetBondType()]
- for bond in a.GetBonds()
- )
+ connectivity = sum(_BONDORDER_TO_ORDER[bond.GetBondType()] for bond in a.GetBonds())
default_valence = periodicTable.GetDefaultValence(atomic_num)
if connectivity == 0: # ions:
# strip catches cases like 'CL1' as name
if atom_name.strip(string.digits).upper() in positive_ions:
- fc = default_valence # e.g. Sodium ions
+ fc = default_valence # e.g. Sodium ions
elif atom_name.strip(string.digits).upper() in negative_ions:
- fc = - default_valence # e.g. Chlorine ions
+ fc = -default_valence # e.g. Chlorine ions
else: # -no-cov-
resn = a.GetMonomerInfo().GetResidueName()
resind = int(a.GetMonomerInfo().GetResidueNumber())
@@ -237,9 +232,9 @@ def _from_openmmPDBFile(cls, openmm_PDBFile: Union[PDBFile, PDBxFile],
f"connectivity{connectivity}"
)
elif default_valence > connectivity:
- fc = - (default_valence - connectivity) # negative charge
+ fc = -(default_valence - connectivity) # negative charge
elif default_valence < connectivity:
- fc = + (connectivity - default_valence) # positive charge
+ fc = +(connectivity - default_valence) # positive charge
else:
fc = 0 # neutral
@@ -272,7 +267,7 @@ def _from_dict(cls, ser_dict: dict, name: str = ""):
mi.SetName(atom[5])
mi.SetResidueName(atom[6])
mi.SetResidueNumber(int(atom[7]))
- mi.SetIsHeteroAtom(atom[8] == 'Y')
+ mi.SetIsHeteroAtom(atom[8] == "Y")
a.SetFormalCharge(atom[9])
a.SetMonomerInfo(mi)
@@ -304,7 +299,7 @@ def _from_dict(cls, ser_dict: dict, name: str = ""):
for bond_id, bond in enumerate(rd_mol.GetBonds()):
# Can't set these on an editable mol, go round a second time
_, _, _, arom = ser_dict["bonds"][bond_id]
- bond.SetIsAromatic(arom == 'Y')
+ bond.SetIsAromatic(arom == "Y")
if "name" in ser_dict:
name = ser_dict["name"]
@@ -319,6 +314,7 @@ def to_openmm_topology(self) -> app.Topology:
openmm.app.Topology
resulting topology obj.
"""
+
def reskey(m):
"""key for defining when a residue has changed from previous
@@ -329,7 +325,7 @@ def reskey(m):
m.GetChainId(),
m.GetResidueName(),
m.GetResidueNumber(),
- m.GetInsertionCode()
+ m.GetInsertionCode(),
)
def chainkey(m):
@@ -362,10 +358,7 @@ def chainkey(m):
if (new_resid := reskey(mi)) != current_resid:
_, resname, resnum, icode = new_resid
- r = top.addResidue(name=resname,
- chain=c,
- id=str(resnum),
- insertionCode=icode)
+ r = top.addResidue(name=resname, chain=c, id=str(resnum), insertionCode=icode)
current_resid = new_resid
a = top.addAtom(
@@ -380,9 +373,7 @@ def chainkey(m):
for bond in self._rdkit.GetBonds():
a1 = atom_lookup[bond.GetBeginAtomIdx()]
a2 = atom_lookup[bond.GetEndAtomIdx()]
- top.addBond(a1, a2,
- order=_BONDORDERS_RDKIT_TO_OPENMM.get(
- bond.GetBondType(), None))
+ top.addBond(a1, a2, order=_BONDORDERS_RDKIT_TO_OPENMM.get(bond.GetBondType(), None))
return top
@@ -400,9 +391,7 @@ def to_openmm_positions(self) -> omm_unit.Quantity:
Quantity containing protein atom positions
"""
np_pos = deserialize_numpy(self.to_dict()["conformers"][0])
- openmm_pos = (
- list(map(lambda x: np.array(x), np_pos)) * omm_unit.angstrom
- )
+ openmm_pos = list(map(lambda x: np.array(x), np_pos)) * omm_unit.angstrom
return openmm_pos
@@ -429,20 +418,18 @@ def to_pdb_file(self, out_path: Union[str, bytes, PathLike[str], PathLike[bytes]
# write file
if not isinstance(out_path, io.TextIOBase):
# allows pathlike/str; we close on completion
- out_file = open(out_path, mode='w') # type: ignore
+ out_file = open(out_path, mode="w") # type: ignore
must_close = True
else:
out_file = out_path # type: ignore
must_close = False
-
+
try:
out_path = out_file.name
except AttributeError:
out_path = ""
- PDBFile.writeFile(
- topology=openmm_top, positions=openmm_pos, file=out_file
- )
+ PDBFile.writeFile(topology=openmm_top, positions=openmm_pos, file=out_file)
if must_close:
# we only close the file if we had to open it
@@ -450,9 +437,7 @@ def to_pdb_file(self, out_path: Union[str, bytes, PathLike[str], PathLike[bytes]
return out_path
- def to_pdbx_file(
- self, out_path: Union[str, bytes, PathLike[str], PathLike[bytes], io.TextIOBase]
- ) -> str:
+ def to_pdbx_file(self, out_path: Union[str, bytes, PathLike[str], PathLike[bytes], io.TextIOBase]) -> str:
"""
serialize protein to pdbx file.
@@ -471,19 +456,17 @@ def to_pdbx_file(
# get pos:
np_pos = deserialize_numpy(self.to_dict()["conformers"][0])
- openmm_pos = (
- list(map(lambda x: np.array(x), np_pos)) * omm_unit.angstrom
- )
+ openmm_pos = list(map(lambda x: np.array(x), np_pos)) * omm_unit.angstrom
# write file
if not isinstance(out_path, io.TextIOBase):
# allows pathlike/str; we close on completion
- out_file = open(out_path, mode='w') # type: ignore
+ out_file = open(out_path, mode="w") # type: ignore
must_close = True
else:
out_file = out_path # type: ignore
must_close = False
-
+
try:
out_path = out_file.name
except AttributeError:
@@ -491,7 +474,6 @@ def to_pdbx_file(
PDBxFile.writeFile(topology=top, positions=openmm_pos, file=out_file)
-
if must_close:
# we only close the file if we had to open it
out_file.close()
@@ -516,7 +498,7 @@ def _to_dict(self) -> dict:
mi.GetName(),
mi.GetResidueName(),
mi.GetResidueNumber(),
- 'Y' if mi.GetIsHeteroAtom() else 'N',
+ "Y" if mi.GetIsHeteroAtom() else "N",
atom.GetFormalCharge(),
)
)
@@ -526,15 +508,14 @@ def _to_dict(self) -> dict:
bond.GetBeginAtomIdx(),
bond.GetEndAtomIdx(),
_BONDORDER_RDKIT_TO_STR[bond.GetBondType()],
- 'Y' if bond.GetIsAromatic() else 'N',
+ "Y" if bond.GetIsAromatic() else "N",
# bond.GetStereo() or "", do we need this? i.e. are openff ffs going to use cis/trans SMARTS?
)
for bond in self._rdkit.GetBonds()
]
conformers = [
- serialize_numpy(conf.GetPositions()) # .m_as(unit.angstrom)
- for conf in self._rdkit.GetConformers()
+ serialize_numpy(conf.GetPositions()) for conf in self._rdkit.GetConformers() # .m_as(unit.angstrom)
]
# Result
diff --git a/gufe/components/smallmoleculecomponent.py b/gufe/components/smallmoleculecomponent.py
index b51f2b7d..814c473b 100644
--- a/gufe/components/smallmoleculecomponent.py
+++ b/gufe/components/smallmoleculecomponent.py
@@ -2,16 +2,17 @@
# For details, see https://github.com/OpenFreeEnergy/gufe
import logging
+import warnings
+
# openff complains about oechem being missing, shhh
-logger = logging.getLogger('openff.toolkit')
+logger = logging.getLogger("openff.toolkit")
logger.setLevel(logging.ERROR)
from typing import Any
from rdkit import Chem
-from .explicitmoleculecomponent import ExplicitMoleculeComponent
from ..molhashing import deserialize_numpy, serialize_numpy
-
+from .explicitmoleculecomponent import ExplicitMoleculeComponent
_INT_TO_ATOMCHIRAL = {
0: Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
@@ -20,14 +21,16 @@
3: Chem.rdchem.ChiralType.CHI_OTHER,
}
# support for non-tetrahedral stereo requires rdkit 2022.09.1+
-if hasattr(Chem.rdchem.ChiralType, 'CHI_TETRAHEDRAL'):
- _INT_TO_ATOMCHIRAL.update({
- 4: Chem.rdchem.ChiralType.CHI_TETRAHEDRAL,
- 5: Chem.rdchem.ChiralType.CHI_ALLENE,
- 6: Chem.rdchem.ChiralType.CHI_SQUAREPLANAR,
- 7: Chem.rdchem.ChiralType.CHI_TRIGONALBIPYRAMIDAL,
- 8: Chem.rdchem.ChiralType.CHI_OCTAHEDRAL,
- })
+if hasattr(Chem.rdchem.ChiralType, "CHI_TETRAHEDRAL"):
+ _INT_TO_ATOMCHIRAL.update(
+ {
+ 4: Chem.rdchem.ChiralType.CHI_TETRAHEDRAL,
+ 5: Chem.rdchem.ChiralType.CHI_ALLENE,
+ 6: Chem.rdchem.ChiralType.CHI_SQUAREPLANAR,
+ 7: Chem.rdchem.ChiralType.CHI_TRIGONALBIPYRAMIDAL,
+ 8: Chem.rdchem.ChiralType.CHI_OCTAHEDRAL,
+ }
+ )
_ATOMCHIRAL_TO_INT = {v: k for k, v in _INT_TO_ATOMCHIRAL.items()}
@@ -53,7 +56,8 @@
18: Chem.rdchem.BondType.DATIVEL,
19: Chem.rdchem.BondType.DATIVER,
20: Chem.rdchem.BondType.OTHER,
- 21: Chem.rdchem.BondType.ZERO}
+ 21: Chem.rdchem.BondType.ZERO,
+}
_BONDTYPE_TO_INT = {v: k for k, v in _INT_TO_BONDTYPE.items()}
_INT_TO_BONDSTEREO = {
0: Chem.rdchem.BondStereo.STEREONONE,
@@ -61,9 +65,24 @@
2: Chem.rdchem.BondStereo.STEREOZ,
3: Chem.rdchem.BondStereo.STEREOE,
4: Chem.rdchem.BondStereo.STEREOCIS,
- 5: Chem.rdchem.BondStereo.STEREOTRANS}
+ 5: Chem.rdchem.BondStereo.STEREOTRANS,
+}
_BONDSTEREO_TO_INT = {v: k for k, v in _INT_TO_BONDSTEREO.items()}
+# following the numbering in rdkit
+_INT_TO_HYBRIDIZATION = {
+ 0: Chem.rdchem.HybridizationType.UNSPECIFIED,
+ 1: Chem.rdchem.HybridizationType.S,
+ 2: Chem.rdchem.HybridizationType.SP,
+ 3: Chem.rdchem.HybridizationType.SP2,
+ 4: Chem.rdchem.HybridizationType.SP3,
+ 5: Chem.rdchem.HybridizationType.SP2D,
+ 6: Chem.rdchem.HybridizationType.SP3D,
+ 7: Chem.rdchem.HybridizationType.SP3D2,
+ 8: Chem.rdchem.HybridizationType.OTHER,
+}
+_HYBRIDIZATION_TO_INT = {v: k for k, v in _INT_TO_HYBRIDIZATION.items()}
+
def _setprops(obj, d: dict) -> None:
# add props onto rdkit "obj" (atom/bond/mol/conformer)
@@ -120,8 +139,8 @@ def to_sdf(self) -> str:
sdf = [Chem.MolToMolBlock(mol)]
for prop in mol.GetPropNames():
val = mol.GetProp(prop)
- sdf.append('> <%s>\n%s\n' % (prop, val))
- sdf.append('$$$$\n')
+ sdf.append(f"> <{prop}>\n{val}\n")
+ sdf.append("$$$$\n")
return "\n".join(sdf)
@classmethod
@@ -210,26 +229,40 @@ def _to_dict(self) -> dict:
atoms = []
for atom in self._rdkit.GetAtoms():
- atoms.append((
- atom.GetAtomicNum(), atom.GetIsotope(), atom.GetFormalCharge(), atom.GetIsAromatic(),
- _ATOMCHIRAL_TO_INT[atom.GetChiralTag()], atom.GetAtomMapNum(),
- atom.GetPropsAsDict(includePrivate=False),
- ))
- output['atoms'] = atoms
+ atoms.append(
+ (
+ atom.GetAtomicNum(),
+ atom.GetIsotope(),
+ atom.GetFormalCharge(),
+ atom.GetIsAromatic(),
+ _ATOMCHIRAL_TO_INT[atom.GetChiralTag()],
+ atom.GetAtomMapNum(),
+ atom.GetPropsAsDict(includePrivate=False),
+ _HYBRIDIZATION_TO_INT[atom.GetHybridization()],
+ )
+ )
+ output["atoms"] = atoms
bonds = []
for bond in self._rdkit.GetBonds():
- bonds.append((
- bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), _BONDTYPE_TO_INT[bond.GetBondType()],
- _BONDSTEREO_TO_INT[bond.GetStereo()],
- bond.GetPropsAsDict(includePrivate=False)
- ))
- output['bonds'] = bonds
+ bonds.append(
+ (
+ bond.GetBeginAtomIdx(),
+ bond.GetEndAtomIdx(),
+ _BONDTYPE_TO_INT[bond.GetBondType()],
+ _BONDSTEREO_TO_INT[bond.GetStereo()],
+ bond.GetPropsAsDict(includePrivate=False),
+ )
+ )
+ output["bonds"] = bonds
conf = self._rdkit.GetConformer()
- output['conformer'] = (serialize_numpy(conf.GetPositions()), conf.GetPropsAsDict(includePrivate=False))
+ output["conformer"] = (
+ serialize_numpy(conf.GetPositions()),
+ conf.GetPropsAsDict(includePrivate=False),
+ )
- output['molprops'] = self._rdkit.GetPropsAsDict(includePrivate=False)
+ output["molprops"] = self._rdkit.GetPropsAsDict(includePrivate=False)
return output
@@ -239,7 +272,7 @@ def _from_dict(cls, d: dict):
m = Chem.Mol()
em = Chem.EditableMol(m)
- for atom in d['atoms']:
+ for atom in d["atoms"]:
a = Chem.Atom(atom[0])
a.SetIsotope(atom[1])
a.SetFormalCharge(atom[2])
@@ -247,26 +280,36 @@ def _from_dict(cls, d: dict):
a.SetChiralTag(_INT_TO_ATOMCHIRAL[atom[4]])
a.SetAtomMapNum(atom[5])
_setprops(a, atom[6])
+ try:
+ a.SetHybridization(_INT_TO_HYBRIDIZATION[atom[7]])
+ except IndexError:
+ warnings.warn(
+ "The atom hybridization data was not found and has been set to unspecified. This can be"
+ " fixed by recreating the SmallMoleculeComponent from the rdkit molecule after running "
+ "sanitization."
+ )
+ pass
+
em.AddAtom(a)
- for bond in d['bonds']:
+ for bond in d["bonds"]:
em.AddBond(bond[0], bond[1], _INT_TO_BONDTYPE[bond[2]])
# other fields are applied onto the ROMol
m = em.GetMol()
- for bond, b in zip(d['bonds'], m.GetBonds()):
+ for bond, b in zip(d["bonds"], m.GetBonds()):
b.SetStereo(_INT_TO_BONDSTEREO[bond[3]])
_setprops(b, bond[4])
- pos = deserialize_numpy(d['conformer'][0])
+ pos = deserialize_numpy(d["conformer"][0])
c = Chem.Conformer(m.GetNumAtoms())
for i, p in enumerate(pos):
c.SetAtomPosition(i, p)
- _setprops(c, d['conformer'][1])
+ _setprops(c, d["conformer"][1])
m.AddConformer(c)
- _setprops(m, d['molprops'])
+ _setprops(m, d["molprops"])
m.UpdatePropertyCache()
@@ -275,10 +318,10 @@ def _from_dict(cls, d: dict):
def copy_with_replacements(self, **replacements):
# this implementation first makes a copy with the name replaced
# only, then does any other replacements that are necessary
- if 'name' in replacements:
- name = replacements.pop('name')
+ if "name" in replacements:
+ name = replacements.pop("name")
dct = self._to_dict()
- dct['molprops']['ofe-name'] = name
+ dct["molprops"]["ofe-name"] = name
obj = self._from_dict(dct)
else:
obj = self
diff --git a/gufe/components/solventcomponent.py b/gufe/components/solventcomponent.py
index a323a3c6..8b9b01e1 100644
--- a/gufe/components/solventcomponent.py
+++ b/gufe/components/solventcomponent.py
@@ -2,13 +2,14 @@
# For details, see https://github.com/OpenFreeEnergy/gufe
from __future__ import annotations
-from openff.units import unit
from typing import Optional, Tuple
+from openff.units import unit
+
from .component import Component
-_CATIONS = {'Cs', 'K', 'Li', 'Na', 'Rb'}
-_ANIONS = {'Cl', 'Br', 'F', 'I'}
+_CATIONS = {"Cs", "K", "Li", "Na", "Rb"}
+_ANIONS = {"Cl", "Br", "F", "I"}
# really wanted to make this a dataclass but then can't sort & strip ion input
@@ -21,18 +22,22 @@ class SolventComponent(Component):
and their coordinates. This abstract representation is later made concrete
by specific MD engine methods.
"""
+
_smiles: str
- _positive_ion: Optional[str]
- _negative_ion: Optional[str]
+ _positive_ion: str | None
+ _negative_ion: str | None
_neutralize: bool
_ion_concentration: unit.Quantity
- def __init__(self, *, # force kwarg usage
- smiles: str = 'O',
- positive_ion: str = 'Na+',
- negative_ion: str = 'Cl-',
- neutralize: bool = True,
- ion_concentration: unit.Quantity = 0.15 * unit.molar):
+ def __init__(
+ self,
+ *, # force kwarg usage
+ smiles: str = "O",
+ positive_ion: str = "Na+",
+ negative_ion: str = "Cl-",
+ neutralize: bool = True,
+ ion_concentration: unit.Quantity = 0.15 * unit.molar,
+ ):
"""
Parameters
----------
@@ -58,26 +63,23 @@ def __init__(self, *, # force kwarg usage
"""
self._smiles = smiles
- norm = positive_ion.strip('-+').capitalize()
+ norm = positive_ion.strip("-+").capitalize()
if norm not in _CATIONS:
raise ValueError(f"Invalid positive ion, got {positive_ion}")
- positive_ion = norm + '+'
+ positive_ion = norm + "+"
self._positive_ion = positive_ion
- norm = negative_ion.strip('-+').capitalize()
+ norm = negative_ion.strip("-+").capitalize()
if norm not in _ANIONS:
raise ValueError(f"Invalid negative ion, got {negative_ion}")
- negative_ion = norm + '-'
+ negative_ion = norm + "-"
self._negative_ion = negative_ion
self._neutralize = neutralize
- if (not isinstance(ion_concentration, unit.Quantity)
- or not ion_concentration.is_compatible_with(unit.molar)):
- raise ValueError(f"ion_concentration must be given in units of"
- f" concentration, got: {ion_concentration}")
+ if not isinstance(ion_concentration, unit.Quantity) or not ion_concentration.is_compatible_with(unit.molar):
+ raise ValueError(f"ion_concentration must be given in units of" f" concentration, got: {ion_concentration}")
if ion_concentration.m < 0:
- raise ValueError(f"ion_concentration must be positive, "
- f"got: {ion_concentration}")
+ raise ValueError(f"ion_concentration must be positive, " f"got: {ion_concentration}")
self._ion_concentration = ion_concentration
@@ -91,12 +93,12 @@ def smiles(self) -> str:
return self._smiles
@property
- def positive_ion(self) -> Optional[str]:
+ def positive_ion(self) -> str | None:
"""The cation in the solvent state"""
return self._positive_ion
@property
- def negative_ion(self) -> Optional[str]:
+ def negative_ion(self) -> str | None:
"""The anion in the solvent state"""
return self._negative_ion
@@ -118,8 +120,8 @@ def total_charge(self):
@classmethod
def _from_dict(cls, d):
"""Deserialize from dict representation"""
- ion_conc = d['ion_concentration']
- d['ion_concentration'] = unit.parse_expression(ion_conc)
+ ion_conc = d["ion_concentration"]
+ d["ion_concentration"] = unit.parse_expression(ion_conc)
return cls(**d)
@@ -127,10 +129,13 @@ def _to_dict(self):
"""For serialization"""
ion_conc = str(self.ion_concentration)
- return {'smiles': self.smiles, 'positive_ion': self.positive_ion,
- 'negative_ion': self.negative_ion,
- 'ion_concentration': ion_conc,
- 'neutralize': self._neutralize}
+ return {
+ "smiles": self.smiles,
+ "positive_ion": self.positive_ion,
+ "negative_ion": self.negative_ion,
+ "ion_concentration": ion_conc,
+ "neutralize": self._neutralize,
+ }
@classmethod
def _defaults(cls):
diff --git a/gufe/custom_codecs.py b/gufe/custom_codecs.py
index a2ed55ea..7499ccd6 100644
--- a/gufe/custom_codecs.py
+++ b/gufe/custom_codecs.py
@@ -5,10 +5,10 @@
import datetime
import functools
import pathlib
+from uuid import UUID
import numpy as np
from openff.units import DEFAULT_UNIT_REGISTRY
-from uuid import UUID
import gufe
from gufe.custom_json import JSONCodec
@@ -62,15 +62,15 @@ def is_openff_quantity_dict(dct):
BYTES_CODEC = JSONCodec(
cls=bytes,
- to_dict=lambda obj: {'latin-1': obj.decode('latin-1')},
- from_dict=lambda dct: dct['latin-1'].encode('latin-1'),
+ to_dict=lambda obj: {"latin-1": obj.decode("latin-1")},
+ from_dict=lambda dct: dct["latin-1"].encode("latin-1"),
)
DATETIME_CODEC = JSONCodec(
cls=datetime.datetime,
- to_dict=lambda obj: {'isotime': obj.isoformat()},
- from_dict=lambda dct: datetime.datetime.fromisoformat(dct['isotime']),
+ to_dict=lambda obj: {"isotime": obj.isoformat()},
+ from_dict=lambda dct: datetime.datetime.fromisoformat(dct["isotime"]),
)
# Note that this has inconsistent behaviour for some generic types
@@ -80,12 +80,10 @@ def is_openff_quantity_dict(dct):
NPY_DTYPE_CODEC = JSONCodec(
cls=np.generic,
to_dict=lambda obj: {
- 'dtype': str(obj.dtype),
- 'bytes': obj.tobytes(),
+ "dtype": str(obj.dtype),
+ "bytes": obj.tobytes(),
},
- from_dict=lambda dct: np.frombuffer(
- dct['bytes'], dtype=np.dtype(dct['dtype'])
- )[0],
+ from_dict=lambda dct: np.frombuffer(dct["bytes"], dtype=np.dtype(dct["dtype"]))[0],
is_my_obj=lambda obj: isinstance(obj, np.generic),
is_my_dict=is_npy_dtype_dict,
)
@@ -94,13 +92,11 @@ def is_openff_quantity_dict(dct):
NUMPY_CODEC = JSONCodec(
cls=np.ndarray,
to_dict=lambda obj: {
- 'dtype': str(obj.dtype),
- 'shape': list(obj.shape),
- 'bytes': obj.tobytes()
+ "dtype": str(obj.dtype),
+ "shape": list(obj.shape),
+ "bytes": obj.tobytes(),
},
- from_dict=lambda dct: np.frombuffer(
- dct['bytes'], dtype=np.dtype(dct['dtype'])
- ).reshape(dct['shape'])
+ from_dict=lambda dct: np.frombuffer(dct["bytes"], dtype=np.dtype(dct["dtype"])).reshape(dct["shape"]),
)
@@ -120,7 +116,7 @@ def is_openff_quantity_dict(dct):
":is_custom:": True,
"pint_unit_registry": "openff_units",
},
- from_dict=lambda dct: dct['magnitude'] * DEFAULT_UNIT_REGISTRY.Quantity(dct['unit']),
+ from_dict=lambda dct: dct["magnitude"] * DEFAULT_UNIT_REGISTRY.Quantity(dct["unit"]),
is_my_obj=lambda obj: isinstance(obj, DEFAULT_UNIT_REGISTRY.Quantity),
is_my_dict=is_openff_quantity_dict,
)
diff --git a/gufe/custom_json.py b/gufe/custom_json.py
index fc14a9aa..535487f6 100644
--- a/gufe/custom_json.py
+++ b/gufe/custom_json.py
@@ -5,10 +5,11 @@
import functools
import json
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
+from collections.abc import Iterable
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
-class JSONCodec(object):
+class JSONCodec:
"""Custom JSON encoding and decoding for non-default types.
Parameters
@@ -32,13 +33,14 @@ class JSONCodec(object):
by this decoder. Default behavior assumes usage of the default
``is_my_obj``.
"""
+
def __init__(
self,
cls: Union[type, None],
- to_dict: Callable[[Any], Dict],
- from_dict: Callable[[Dict], Any],
+ to_dict: Callable[[Any], dict],
+ from_dict: Callable[[dict], Any],
is_my_obj: Optional[Callable[[Any], bool]] = None,
- is_my_dict=None
+ is_my_dict=None,
):
if is_my_obj is None:
is_my_obj = self._is_my_obj
@@ -53,7 +55,7 @@ def __init__(
self.is_my_dict = is_my_dict
def _is_my_dict(self, dct: dict) -> bool:
- expected = ['__class__', '__module__', ':is_custom:']
+ expected = ["__class__", "__module__", ":is_custom:"]
is_custom = all(exp in dct for exp in expected)
return (
is_custom
@@ -80,7 +82,7 @@ def default(self, obj: Any) -> Any:
return dct
return obj
- def object_hook(self, dct: Dict) -> Any:
+ def object_hook(self, dct: dict) -> Any:
if self.is_my_dict(dct):
obj = self.from_dict(dct)
return obj
@@ -88,8 +90,8 @@ def object_hook(self, dct: Dict) -> Any:
def custom_json_factory(
- coding_methods: Iterable[JSONCodec]
-) -> Tuple[Type[json.JSONEncoder], Type[json.JSONDecoder]]:
+ coding_methods: Iterable[JSONCodec],
+) -> tuple[type[json.JSONEncoder], type[json.JSONDecoder]]:
"""Create JSONEncoder/JSONDecoder for special types.
Factory method. Dynamically creates classes that enable all the provided
@@ -107,6 +109,7 @@ def custom_json_factory(
subclasses of JSONEncoder/JSONDecoder that use support the provided
codecs
"""
+
class CustomJSONEncoder(json.JSONEncoder):
def default(self, obj):
for coding_method in coding_methods:
@@ -125,9 +128,7 @@ class CustomJSONDecoder(json.JSONDecoder):
def __init__(self, *args, **kwargs):
# technically, JSONDecoder doesn't come with an object_hook
# method, which is why we pass it to super here
- super(CustomJSONDecoder, self).__init__(
- object_hook=self.object_hook, *args, **kwargs
- )
+ super().__init__(object_hook=self.object_hook, *args, **kwargs)
def object_hook(self, dct):
for coding_method in coding_methods:
@@ -144,8 +145,8 @@ def object_hook(self, dct):
return (CustomJSONEncoder, CustomJSONDecoder)
-class JSONSerializerDeserializer(object):
- """
+class JSONSerializerDeserializer:
+ r"""
Tools to serialize and deserialize objects as JSON.
This wrapper object is necessary so that we can register new codecs
@@ -164,15 +165,17 @@ class JSONSerializerDeserializer(object):
codecs : list of :class:`.JSONCodec`\s
codecs supported
"""
+
def __init__(self, codecs: Iterable[JSONCodec]):
- self.codecs: List[JSONCodec] = []
+ self.codecs: list[JSONCodec] = []
for codec in codecs:
self.add_codec(codec)
self.encoder, self.decoder = self._set_serialization()
- def _set_serialization(self) -> Tuple[Type[json.JSONEncoder],
- Type[json.JSONDecoder]]:
+ def _set_serialization(
+ self,
+ ) -> tuple[type[json.JSONEncoder], type[json.JSONDecoder]]:
encoder, decoder = custom_json_factory(self.codecs)
self._serializer = functools.partial(json.dumps, cls=encoder)
self._deserializer = functools.partial(json.loads, cls=decoder)
@@ -194,7 +197,6 @@ def add_codec(self, codec: JSONCodec):
self.encoder, self.decoder = self._set_serialization()
-
def serializer(self, obj: Any) -> str:
"""Callable that dumps to JSON"""
return self._serializer(obj)
diff --git a/gufe/custom_typing.py b/gufe/custom_typing.py
index 3a1ce2a9..38f91625 100644
--- a/gufe/custom_typing.py
+++ b/gufe/custom_typing.py
@@ -2,6 +2,7 @@
# For details, see https://github.com/OpenFreeEnergy/gufe
from typing import TypeVar
+
from rdkit import Chem
try:
diff --git a/gufe/ligandnetwork.py b/gufe/ligandnetwork.py
index c43dce57..45eac118 100644
--- a/gufe/ligandnetwork.py
+++ b/gufe/ligandnetwork.py
@@ -2,21 +2,24 @@
# For details, see https://github.com/OpenFreeEnergy/gufe
from __future__ import annotations
-from itertools import chain
import json
+from collections.abc import Iterable
+from itertools import chain
+from typing import FrozenSet, Optional
+
import networkx as nx
-from typing import FrozenSet, Iterable, Optional
-import gufe
+import gufe
from gufe import SmallMoleculeComponent
+
from .mapping import LigandAtomMapping
-from .tokenization import GufeTokenizable, JSON_HANDLER
+from .tokenization import JSON_HANDLER, GufeTokenizable
class LigandNetwork(GufeTokenizable):
"""A directed graph connecting ligands according to their atom mapping.
A network can be defined by specifying only edges, in which case the nodes are implicitly added.
-
+
Parameters
----------
@@ -26,17 +29,17 @@ class LigandNetwork(GufeTokenizable):
Nodes for this network. Any nodes already included as a part of the 'edges' will be ignored.
Nodes not already included in 'edges' will be added as isolated, unconnected nodes.
"""
+
def __init__(
self,
edges: Iterable[LigandAtomMapping],
- nodes: Optional[Iterable[SmallMoleculeComponent]] = None
+ nodes: Iterable[SmallMoleculeComponent] | None = None,
):
if nodes is None:
nodes = []
self._edges = frozenset(edges)
- edge_nodes = set(chain.from_iterable((e.componentA, e.componentB)
- for e in edges))
+ edge_nodes = set(chain.from_iterable((e.componentA, e.componentB) for e in edges))
self._nodes = frozenset(edge_nodes) | frozenset(nodes)
self._graph = None
@@ -45,11 +48,11 @@ def _defaults(cls):
return {}
def _to_dict(self) -> dict:
- return {'graphml': self.to_graphml()}
+ return {"graphml": self.to_graphml()}
@classmethod
def _from_dict(cls, dct: dict):
- return cls.from_graphml(dct['graphml'])
+ return cls.from_graphml(dct["graphml"])
@property
def graph(self) -> nx.MultiDiGraph:
@@ -65,20 +68,19 @@ def graph(self) -> nx.MultiDiGraph:
for node in sorted(self._nodes):
graph.add_node(node)
for edge in sorted(self._edges):
- graph.add_edge(edge.componentA, edge.componentB, object=edge,
- **edge.annotations)
+ graph.add_edge(edge.componentA, edge.componentB, object=edge, **edge.annotations)
self._graph = nx.freeze(graph)
return self._graph
@property
- def edges(self) -> FrozenSet[LigandAtomMapping]:
+ def edges(self) -> frozenset[LigandAtomMapping]:
"""A read-only view of the edges of the Network"""
return self._edges
@property
- def nodes(self) -> FrozenSet[SmallMoleculeComponent]:
+ def nodes(self) -> frozenset[SmallMoleculeComponent]:
"""A read-only view of the nodes of the Network"""
return self._nodes
@@ -93,25 +95,24 @@ def _serializable_graph(self) -> nx.Graph:
# identical networks will show no changes if you diff their
# serialized versions
sorted_nodes = sorted(self.nodes, key=lambda m: (m.smiles, m.name))
- mol_to_label = {mol: f"mol{num}"
- for num, mol in enumerate(sorted_nodes)}
-
- edge_data = sorted([
- (
- mol_to_label[edge.componentA],
- mol_to_label[edge.componentB],
- json.dumps(list(edge.componentA_to_componentB.items())),
- json.dumps(edge.annotations, cls=JSON_HANDLER.encoder),
- )
- for edge in self.edges
- ])
+ mol_to_label = {mol: f"mol{num}" for num, mol in enumerate(sorted_nodes)}
+
+ edge_data = sorted(
+ [
+ (
+ mol_to_label[edge.componentA],
+ mol_to_label[edge.componentB],
+ json.dumps(list(edge.componentA_to_componentB.items())),
+ json.dumps(edge.annotations, cls=JSON_HANDLER.encoder),
+ )
+ for edge in self.edges
+ ]
+ )
# from here, we just build the graph
serializable_graph = nx.MultiDiGraph()
for mol, label in mol_to_label.items():
- serializable_graph.add_node(label,
- moldict=json.dumps(mol.to_dict(),
- sort_keys=True))
+ serializable_graph.add_node(label, moldict=json.dumps(mol.to_dict(), sort_keys=True))
for molA, molB, mapping, annotation in edge_data:
serializable_graph.add_edge(molA, molB, mapping=mapping, annotations=annotation)
@@ -124,15 +125,19 @@ def _from_serializable_graph(cls, graph: nx.Graph):
This is the inverse of ``_serializable_graph``.
"""
- label_to_mol = {node: SmallMoleculeComponent.from_dict(json.loads(d))
- for node, d in graph.nodes(data='moldict')}
+ label_to_mol = {
+ node: SmallMoleculeComponent.from_dict(json.loads(d)) for node, d in graph.nodes(data="moldict")
+ }
edges = [
- LigandAtomMapping(componentA=label_to_mol[node1],
- componentB=label_to_mol[node2],
- componentA_to_componentB=dict(json.loads(edge_data["mapping"])),
- annotations=json.loads(edge_data.get("annotations", 'null'), cls=JSON_HANDLER.decoder) # work around old graphml files with missing edge annotations
- )
+ LigandAtomMapping(
+ componentA=label_to_mol[node1],
+ componentB=label_to_mol[node2],
+ componentA_to_componentB=dict(json.loads(edge_data["mapping"])),
+ annotations=json.loads(
+ edge_data.get("annotations", "null"), cls=JSON_HANDLER.decoder
+ ), # work around old graphml files with missing edge annotations
+ )
for node1, node2, edge_data in graph.edges(data=True)
]
@@ -183,10 +188,10 @@ def enlarge_graph(self, *, edges=None, nodes=None) -> LigandNetwork:
a new network adding the given edges and nodes to this network
"""
if edges is None:
- edges = set([])
+ edges = set()
if nodes is None:
- nodes = set([])
+ nodes = set()
return LigandNetwork(self.edges | set(edges), self.nodes | set(nodes))
@@ -198,7 +203,7 @@ def _to_rfe_alchemical_network(
*,
alchemical_label: str = "ligand",
autoname=True,
- autoname_prefix=""
+ autoname_prefix="",
) -> gufe.AlchemicalNetwork:
"""
Parameters
@@ -228,8 +233,7 @@ def sys_from_dict(component):
"""
syscomps = {alchemical_label: component}
other_labels = set(labels) - {alchemical_label}
- syscomps.update({label: components[label]
- for label in other_labels})
+ syscomps.update({label: components[label] for label in other_labels})
if autoname:
name = f"{component.name}_{leg_name}"
@@ -246,9 +250,7 @@ def sys_from_dict(component):
else:
name = ""
- transformation = gufe.Transformation(sysA, sysB, protocol,
- mapping=edge,
- name=name)
+ transformation = gufe.Transformation(sysA, sysB, protocol, mapping=edge, name=name)
transformations.append(transformation)
@@ -262,7 +264,7 @@ def to_rbfe_alchemical_network(
*,
autoname: bool = True,
autoname_prefix: str = "easy_rbfe",
- **other_components
+ **other_components,
) -> gufe.AlchemicalNetwork:
"""Convert the ligand network to an AlchemicalNetwork
@@ -279,22 +281,17 @@ def to_rbfe_alchemical_network(
additional non-alchemical components, keyword will be the string
label for the component
"""
- components = {
- 'protein': protein,
- 'solvent': solvent,
- **other_components
- }
+ components = {"protein": protein, "solvent": solvent, **other_components}
leg_labels = {
"solvent": ["ligand", "solvent"],
- "complex": (["ligand", "solvent", "protein"]
- + list(other_components)),
+ "complex": (["ligand", "solvent", "protein"] + list(other_components)),
}
return self._to_rfe_alchemical_network(
components=components,
leg_labels=leg_labels,
protocol=protocol,
autoname=autoname,
- autoname_prefix=autoname_prefix
+ autoname_prefix=autoname_prefix,
)
# on hold until we figure out how to best hack in the PME/NoCutoff
diff --git a/gufe/mapping/__init__.py b/gufe/mapping/__init__.py
index bd6be51b..e77a7208 100644
--- a/gufe/mapping/__init__.py
+++ b/gufe/mapping/__init__.py
@@ -1,7 +1,7 @@
# This code is part of gufe and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/gufe
"""Defining the relationship between different components"""
-from .componentmapping import ComponentMapping
-from .atom_mapping import AtomMapping
from .atom_mapper import AtomMapper
+from .atom_mapping import AtomMapping
+from .componentmapping import ComponentMapping
from .ligandatommapping import LigandAtomMapping
diff --git a/gufe/mapping/atom_mapper.py b/gufe/mapping/atom_mapper.py
index 0fe1c9cd..c2844979 100644
--- a/gufe/mapping/atom_mapper.py
+++ b/gufe/mapping/atom_mapper.py
@@ -2,6 +2,7 @@
# For details, see https://github.com/OpenFreeEnergy/gufe
import abc
from collections.abc import Iterator
+
import gufe
from ..tokenization import GufeTokenizable
@@ -18,10 +19,7 @@ class AtomMapper(GufeTokenizable):
"""
@abc.abstractmethod
- def suggest_mappings(self,
- A: gufe.Component,
- B: gufe.Component
- ) -> Iterator[AtomMapping]:
+ def suggest_mappings(self, A: gufe.Component, B: gufe.Component) -> Iterator[AtomMapping]:
"""Suggests possible mappings between two Components
Suggests zero or more :class:`.AtomMapping` objects, which are possible
diff --git a/gufe/mapping/atom_mapping.py b/gufe/mapping/atom_mapping.py
index 58c617f3..e91bf2d4 100644
--- a/gufe/mapping/atom_mapping.py
+++ b/gufe/mapping/atom_mapping.py
@@ -2,10 +2,10 @@
# For details, see https://github.com/OpenFreeEnergy/gufe
import abc
-from collections.abc import Mapping, Iterable
-
+from collections.abc import Iterable, Mapping
import gufe
+
from .componentmapping import ComponentMapping
diff --git a/gufe/mapping/componentmapping.py b/gufe/mapping/componentmapping.py
index 92a098f3..9c1c6965 100644
--- a/gufe/mapping/componentmapping.py
+++ b/gufe/mapping/componentmapping.py
@@ -11,6 +11,7 @@ class ComponentMapping(GufeTokenizable, abc.ABC):
For components that are atom-based is specialised to :class:`.AtomMapping`
"""
+
_componentA: gufe.Component
_componentB: gufe.Component
diff --git a/gufe/mapping/ligandatommapping.py b/gufe/mapping/ligandatommapping.py
index cc8ada71..4d322000 100644
--- a/gufe/mapping/ligandatommapping.py
+++ b/gufe/mapping/ligandatommapping.py
@@ -4,13 +4,15 @@
import json
from typing import Any, Optional
+
import numpy as np
from numpy.typing import NDArray
from gufe.components import SmallMoleculeComponent
from gufe.visualization.mapping_visualization import draw_mapping
-from . import AtomMapping
+
from ..tokenization import JSON_HANDLER
+from . import AtomMapping
class LigandAtomMapping(AtomMapping):
@@ -21,6 +23,7 @@ class LigandAtomMapping(AtomMapping):
:class:`.SmallMoleculeComponent` which stores the mapping as a dict of
integers.
"""
+
componentA: SmallMoleculeComponent
componentB: SmallMoleculeComponent
_annotations: dict[str, Any]
@@ -31,7 +34,7 @@ def __init__(
componentA: SmallMoleculeComponent,
componentB: SmallMoleculeComponent,
componentA_to_componentB: dict[int, int],
- annotations: Optional[dict[str, Any]] = None,
+ annotations: dict[str, Any] | None = None,
):
"""
Parameters
@@ -57,11 +60,9 @@ def __init__(
nB = self.componentB.to_rdkit().GetNumAtoms()
for i, j in componentA_to_componentB.items():
if not (0 <= i < nA):
- raise ValueError(f"Got invalid index for ComponentA ({i}); "
- f"must be 0 <= n < {nA}")
+ raise ValueError(f"Got invalid index for ComponentA ({i}); " f"must be 0 <= n < {nA}")
if not (0 <= j < nB):
- raise ValueError(f"Got invalid index for ComponentB ({i}); "
- f"must be 0 <= n < {nB}")
+ raise ValueError(f"Got invalid index for ComponentB ({i}); " f"must be 0 <= n < {nB}")
self._compA_to_compB = componentA_to_componentB
@@ -84,48 +85,51 @@ def componentB_to_componentA(self) -> dict[int, int]:
@property
def componentA_unique(self):
- return (i for i in range(self.componentA.to_rdkit().GetNumAtoms())
- if i not in self._compA_to_compB)
+ return (i for i in range(self.componentA.to_rdkit().GetNumAtoms()) if i not in self._compA_to_compB)
@property
def componentB_unique(self):
- return (i for i in range(self.componentB.to_rdkit().GetNumAtoms())
- if i not in self._compA_to_compB.values())
+ return (i for i in range(self.componentB.to_rdkit().GetNumAtoms()) if i not in self._compA_to_compB.values())
@property
def annotations(self):
"""Any extra metadata, for example the score of a mapping"""
# return a copy (including copy of nested)
- return json.loads(json.dumps(self._annotations, cls=JSON_HANDLER.encoder), cls=JSON_HANDLER.decoder)
+ return json.loads(
+ json.dumps(self._annotations, cls=JSON_HANDLER.encoder),
+ cls=JSON_HANDLER.decoder,
+ )
def _to_dict(self):
"""Serialize to dict"""
return {
- 'componentA': self.componentA,
- 'componentB': self.componentB,
- 'componentA_to_componentB': self._compA_to_compB,
- 'annotations': json.dumps(self._annotations, sort_keys=True, cls=JSON_HANDLER.encoder),
+ "componentA": self.componentA,
+ "componentB": self.componentB,
+ "componentA_to_componentB": self._compA_to_compB,
+ "annotations": json.dumps(self._annotations, sort_keys=True, cls=JSON_HANDLER.encoder),
}
@classmethod
def _from_dict(cls, d: dict):
"""Deserialize from dict"""
# the mapping dict gets mangled sometimes
- mapping = d['componentA_to_componentB']
+ mapping = d["componentA_to_componentB"]
fixed = {int(k): int(v) for k, v in mapping.items()}
return cls(
- componentA=d['componentA'],
- componentB=d['componentB'],
+ componentA=d["componentA"],
+ componentB=d["componentB"],
componentA_to_componentB=fixed,
- annotations=json.loads(d['annotations'], cls=JSON_HANDLER.decoder)
+ annotations=json.loads(d["annotations"], cls=JSON_HANDLER.decoder),
)
def __repr__(self):
- return (f"{self.__class__.__name__}(componentA={self.componentA!r}, "
- f"componentB={self.componentB!r}, "
- f"componentA_to_componentB={self._compA_to_compB!r}, "
- f"annotations={self.annotations!r})")
+ return (
+ f"{self.__class__.__name__}(componentA={self.componentA!r}, "
+ f"componentB={self.componentB!r}, "
+ f"componentA_to_componentB={self._compA_to_compB!r}, "
+ f"annotations={self.annotations!r})"
+ )
def _ipython_display_(self, d2d=None): # pragma: no-cover
"""
@@ -144,9 +148,16 @@ def _ipython_display_(self, d2d=None): # pragma: no-cover
"""
from IPython.display import Image, display
- return display(Image(draw_mapping(self._compA_to_compB,
- self.componentA.to_rdkit(),
- self.componentB.to_rdkit(), d2d)))
+ return display(
+ Image(
+ draw_mapping(
+ self._compA_to_compB,
+ self.componentA.to_rdkit(),
+ self.componentB.to_rdkit(),
+ d2d,
+ )
+ )
+ )
def draw_to_file(self, fname: str, d2d=None):
"""
@@ -162,16 +173,26 @@ def draw_to_file(self, fname: str, d2d=None):
fname : str
Name of file to save atom map
"""
- data = draw_mapping(self._compA_to_compB, self.componentA.to_rdkit(),
- self.componentB.to_rdkit(), d2d)
+ data = draw_mapping(
+ self._compA_to_compB,
+ self.componentA.to_rdkit(),
+ self.componentB.to_rdkit(),
+ d2d,
+ )
if type(data) == bytes:
mode = "wb"
else:
mode = "w"
with open(fname, mode) as f:
- f.write(draw_mapping(self._compA_to_compB, self.componentA.to_rdkit(),
- self.componentB.to_rdkit(), d2d))
+ f.write(
+ draw_mapping(
+ self._compA_to_compB,
+ self.componentA.to_rdkit(),
+ self.componentB.to_rdkit(),
+ d2d,
+ )
+ )
def with_annotations(self, annotations: dict[str, Any]) -> LigandAtomMapping:
"""Create a new mapping based on this one with extra annotations.
@@ -187,7 +208,7 @@ def with_annotations(self, annotations: dict[str, Any]) -> LigandAtomMapping:
componentA=self.componentA,
componentB=self.componentB,
componentA_to_componentB=self._compA_to_compB,
- annotations=dict(**self.annotations, **annotations)
+ annotations=dict(**self.annotations, **annotations),
)
def get_distances(self) -> NDArray[np.float64]:
diff --git a/gufe/molhashing.py b/gufe/molhashing.py
index 37bd64cf..5c8daf08 100644
--- a/gufe/molhashing.py
+++ b/gufe/molhashing.py
@@ -1,6 +1,7 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/gufe
import io
+
import numpy as np
@@ -15,7 +16,7 @@ def serialize_numpy(arr: np.ndarray) -> str:
np.save(npbytes, arr, allow_pickle=False)
npbytes.seek(0)
# latin-1 or base64? latin-1 is fewer bytes, but arguably worse on eyes
- return npbytes.read().decode('latin-1')
+ return npbytes.read().decode("latin-1")
def deserialize_numpy(arr_str: str) -> np.ndarray:
@@ -25,6 +26,6 @@ def deserialize_numpy(arr_str: str) -> np.ndarray:
-------
:func:`.serialize_numpy`
"""
- npbytes = io.BytesIO(arr_str.encode('latin-1'))
+ npbytes = io.BytesIO(arr_str.encode("latin-1"))
npbytes.seek(0)
return np.load(npbytes)
diff --git a/gufe/network.py b/gufe/network.py
index 59929d68..18bb06d9 100644
--- a/gufe/network.py
+++ b/gufe/network.py
@@ -1,17 +1,17 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/gufe
-from typing import Generator, Iterable, Optional
-from typing_extensions import Self # Self is included in typing as of python 3.11
+from collections.abc import Generator, Iterable
+from typing import Optional
import networkx as nx
-from .tokenization import GufeTokenizable
+from typing_extensions import Self # Self is included in typing as of python 3.11
from .chemicalsystem import ChemicalSystem
+from .tokenization import GufeTokenizable
from .transformations import Transformation
-
class AlchemicalNetwork(GufeTokenizable):
_edges: frozenset[Transformation]
_nodes: frozenset[ChemicalSystem]
@@ -31,6 +31,7 @@ class AlchemicalNetwork(GufeTokenizable):
the individual chemical states. :class:`.ChemicalSystem` objects from
Transformation objects in edges will be automatically extracted
"""
+
def __init__(
self,
edges: Optional[Iterable[Transformation]] = None,
@@ -47,11 +48,7 @@ def __init__(
else:
self._nodes = frozenset(nodes)
- self._nodes = (
- self._nodes
- | frozenset(e.stateA for e in self._edges)
- | frozenset(e.stateB for e in self._edges)
- )
+ self._nodes = self._nodes | frozenset(e.stateA for e in self._edges) | frozenset(e.stateB for e in self._edges)
self._graph = None
@@ -60,9 +57,7 @@ def _generate_graph(edges, nodes) -> nx.MultiDiGraph:
g = nx.MultiDiGraph()
for transformation in edges:
- g.add_edge(
- transformation.stateA, transformation.stateB, object=transformation
- )
+ g.add_edge(transformation.stateA, transformation.stateB, object=transformation)
g.add_nodes_from(nodes)
@@ -99,15 +94,15 @@ def name(self) -> Optional[str]:
return self._name
def _to_dict(self) -> dict:
- return {"nodes": sorted(self.nodes),
- "edges": sorted(self.edges),
- "name": self.name}
+ return {
+ "nodes": sorted(self.nodes),
+ "edges": sorted(self.edges),
+ "name": self.name,
+ }
@classmethod
def _from_dict(cls, d: dict) -> Self:
- return cls(nodes=frozenset(d['nodes']),
- edges=frozenset(d['edges']),
- name=d.get('name'))
+ return cls(nodes=frozenset(d["nodes"]), edges=frozenset(d["edges"]), name=d.get("name"))
@classmethod
def _defaults(cls):
@@ -126,7 +121,7 @@ def from_graphml(cls, str) -> Self:
def _from_nx_graph(cls, nx_graph) -> Self:
"""Create an alchemical network from a networkx representation."""
chemical_systems = [n for n in nx_graph.nodes()]
- transformations = [e[2]['object'] for e in nx_graph.edges(data=True)]
+ transformations = [e[2]["object"] for e in nx_graph.edges(data=True)]
return cls(nodes=chemical_systems, edges=transformations)
def connected_subgraphs(self) -> Generator[Self, None, None]:
@@ -135,4 +130,4 @@ def connected_subgraphs(self) -> Generator[Self, None, None]:
for node_group in node_groups:
nx_subgraph = self.graph.subgraph(node_group)
alc_subgraph = self._from_nx_graph(nx_subgraph)
- yield(alc_subgraph)
+ yield (alc_subgraph)
diff --git a/gufe/protocols/__init__.py b/gufe/protocols/__init__.py
index 9b13eff6..fca9b580 100644
--- a/gufe/protocols/__init__.py
+++ b/gufe/protocols/__init__.py
@@ -1,4 +1,5 @@
"""Defining processes for performing estimates of free energy differences"""
+
from .protocol import Protocol, ProtocolResult
from .protocoldag import ProtocolDAG, ProtocolDAGResult, execute_DAG
-from .protocolunit import ProtocolUnit, ProtocolUnitResult, ProtocolUnitFailure, Context
+from .protocolunit import Context, ProtocolUnit, ProtocolUnitFailure, ProtocolUnitResult
diff --git a/gufe/protocols/protocol.py b/gufe/protocols/protocol.py
index bfda10c8..d086a2be 100644
--- a/gufe/protocols/protocol.py
+++ b/gufe/protocols/protocol.py
@@ -6,15 +6,16 @@
"""
import abc
-from typing import Optional, Iterable, Any, Union
-from openff.units import Quantity
import warnings
+from collections.abc import Iterable, Sized
+from typing import Any, Optional, Union
+
+from openff.units import Quantity
-from ..settings import Settings, SettingsBaseModel
-from ..tokenization import GufeTokenizable, GufeKey
from ..chemicalsystem import ChemicalSystem
from ..mapping import ComponentMapping
-
+from ..settings import Settings, SettingsBaseModel
+from ..tokenization import GufeKey, GufeTokenizable
from .protocoldag import ProtocolDAG, ProtocolDAGResult
from .protocolunit import ProtocolUnit
@@ -32,19 +33,33 @@ class ProtocolResult(GufeTokenizable):
- `get_uncertainty`
"""
- def __init__(self, **data):
+ def __init__(self, n_protocol_dag_results: int = 0, **data):
self._data = data
+ if not n_protocol_dag_results >= 0:
+ raise ValueError("`n_protocol_dag_results` must be an integer greater than or equal to zero")
+
+ self._n_protocol_dag_results = n_protocol_dag_results
+
@classmethod
def _defaults(cls):
return {}
def _to_dict(self):
- return {'data': self.data}
+ return {"n_protocol_dag_results": self.n_protocol_dag_results, "data": self.data}
@classmethod
def _from_dict(cls, dct: dict):
- return cls(**dct['data'])
+ # TODO: remove in gufe 2.0
+ try:
+ n_protocol_dag_results = dct["n_protocol_dag_results"]
+ except KeyError:
+ n_protocol_dag_results = 0
+ return cls(n_protocol_dag_results=n_protocol_dag_results, **dct["data"])
+
+ @property
+ def n_protocol_dag_results(self) -> int:
+ return self._n_protocol_dag_results
@property
def data(self) -> dict[str, Any]:
@@ -57,12 +72,10 @@ def data(self) -> dict[str, Any]:
return self._data
@abc.abstractmethod
- def get_estimate(self) -> Quantity:
- ...
+ def get_estimate(self) -> Quantity: ...
@abc.abstractmethod
- def get_uncertainty(self) -> Quantity:
- ...
+ def get_uncertainty(self) -> Quantity: ...
class Protocol(GufeTokenizable):
@@ -80,6 +93,7 @@ class Protocol(GufeTokenizable):
- `_gather`
- `_default_settings`
"""
+
_settings: Settings
result_cls: type[ProtocolResult]
"""Corresponding `ProtocolResult` subclass."""
@@ -109,7 +123,7 @@ def _defaults(cls):
return {}
def _to_dict(self):
- return {'settings': self.settings}
+ return {"settings": self.settings}
@classmethod
def _from_dict(cls, dct: dict):
@@ -180,9 +194,9 @@ def create(
mapping: Optional[Union[ComponentMapping, list[ComponentMapping], dict[str, ComponentMapping]]],
extends: Optional[ProtocolDAGResult] = None,
name: Optional[str] = None,
- transformation_key: Optional[GufeKey] = None
+ transformation_key: Optional[GufeKey] = None,
) -> ProtocolDAG:
- """Prepare a `ProtocolDAG` with all information required for execution.
+ r"""Prepare a `ProtocolDAG` with all information required for execution.
A :class:`.ProtocolDAG` is composed of :class:`.ProtocolUnit` \s, with
dependencies established between them. These form a directed, acyclic
@@ -221,9 +235,10 @@ def create(
A directed, acyclic graph that can be executed by a `Scheduler`.
"""
if isinstance(mapping, dict):
- warnings.warn(("mapping input as a dict is deprecated, "
- "instead use either a single Mapping or list"),
- DeprecationWarning)
+ warnings.warn(
+ ("mapping input as a dict is deprecated, " "instead use either a single Mapping or list"),
+ DeprecationWarning,
+ )
mapping = list(mapping.values())
return ProtocolDAG(
@@ -235,12 +250,10 @@ def create(
extends=extends,
),
transformation_key=transformation_key,
- extends_key=extends.key if extends is not None else None
+ extends_key=extends.key if extends is not None else None,
)
- def gather(
- self, protocol_dag_results: Iterable[ProtocolDAGResult]
- ) -> ProtocolResult:
+ def gather(self, protocol_dag_results: Iterable[ProtocolDAGResult]) -> ProtocolResult:
"""Gather multiple ProtocolDAGResults into a single ProtocolResult.
Parameters
@@ -254,12 +267,16 @@ def gather(
ProtocolResult
Aggregated results from many `ProtocolDAGResult`s from a given `Protocol`.
"""
- return self.result_cls(**self._gather(protocol_dag_results))
+ # Iterable does not implement __len__ and makes no guarantees that
+ # protocol_dag_results is finite, checking both in method signature
+ # doesn't appear possible, explicitly check for __len__ through the
+ # Sized type
+ if not isinstance(protocol_dag_results, Sized):
+ raise ValueError("`protocol_dag_results` must implement `__len__`")
+ return self.result_cls(n_protocol_dag_results=len(protocol_dag_results), **self._gather(protocol_dag_results))
@abc.abstractmethod
- def _gather(
- self, protocol_dag_results: Iterable[ProtocolDAGResult]
- ) -> dict[str, Any]:
+ def _gather(self, protocol_dag_results: Iterable[ProtocolDAGResult]) -> dict[str, Any]:
"""Method to override in custom Protocol subclasses.
This method should take any number of ``ProtocolDAGResult``s produced
diff --git a/gufe/protocols/protocoldag.py b/gufe/protocols/protocoldag.py
index d6958b76..74e7afbd 100644
--- a/gufe/protocols/protocoldag.py
+++ b/gufe/protocols/protocoldag.py
@@ -2,20 +2,19 @@
# For details, see https://github.com/OpenFreeEnergy/gufe
import abc
-from copy import copy
-from collections import defaultdict
import os
-from typing import Iterable, Optional, Union, Any
+import shutil
+from collections import defaultdict
+from collections.abc import Iterable
+from copy import copy
from os import PathLike
from pathlib import Path
-import shutil
+from typing import Any, Optional, Union
import networkx as nx
-from ..tokenization import GufeTokenizable, GufeKey
-from .protocolunit import (
- ProtocolUnit, ProtocolUnitResult, ProtocolUnitFailure, Context
-)
+from ..tokenization import GufeKey, GufeTokenizable
+from .protocolunit import Context, ProtocolUnit, ProtocolUnitFailure, ProtocolUnitResult
class DAGMixin:
@@ -32,7 +31,7 @@ class DAGMixin:
## key of the ProtocolDAG this DAG extends
_extends_key: Optional[GufeKey]
- @staticmethod
+ @staticmethod
def _build_graph(nodes):
"""Build dependency DAG of ProtocolUnits with input keys stored on edges"""
G = nx.DiGraph()
@@ -46,9 +45,7 @@ def _build_graph(nodes):
@staticmethod
def _iterate_dag_order(graph):
- return reversed(
- list(nx.lexicographical_topological_sort(graph, key=lambda pu: pu.key))
- )
+ return reversed(list(nx.lexicographical_topological_sort(graph, key=lambda pu: pu.key)))
@property
def name(self) -> Optional[str]:
@@ -104,13 +101,13 @@ class ProtocolDAGResult(GufeTokenizable, DAGMixin):
There may be many of these for a given `Transformation`. Data elements from
these objects are combined by `Protocol.gather` into a `ProtocolResult`.
"""
+
_protocol_unit_results: list[ProtocolUnitResult]
_unit_result_mapping: dict[ProtocolUnit, list[ProtocolUnitResult]]
_result_unit_mapping: dict[ProtocolUnitResult, ProtocolUnit]
-
def __init__(
- self,
+ self,
*,
protocol_units: list[ProtocolUnit],
protocol_unit_results: list[ProtocolUnitResult],
@@ -147,11 +144,13 @@ def _defaults(cls):
return {}
def _to_dict(self):
- return {'name': self.name,
- 'protocol_units': self._protocol_units,
- 'protocol_unit_results': self._protocol_unit_results,
- 'transformation_key': self._transformation_key,
- 'extends_key': self._extends_key}
+ return {
+ "name": self.name,
+ "protocol_units": self._protocol_units,
+ "protocol_unit_results": self._protocol_unit_results,
+ "transformation_key": self._transformation_key,
+ "extends_key": self._extends_key,
+ }
@classmethod
def _from_dict(cls, dct: dict):
@@ -190,7 +189,7 @@ def protocol_unit_failures(self) -> list[ProtocolUnitFailure]:
# mypy can't figure out the types here, .ok() will ensure a certain type
# https://mypy.readthedocs.io/en/stable/common_issues.html?highlight=cast#complex-type-tests
return [r for r in self.protocol_unit_results if not r.ok()] # type: ignore
-
+
@property
def protocol_unit_successes(self) -> list[ProtocolUnitResult]:
"""A list of only successful `ProtocolUnit` results.
@@ -252,8 +251,7 @@ def result_to_unit(self, protocol_unit_result: ProtocolUnitResult) -> ProtocolUn
def ok(self) -> bool:
# ensure that for every protocol unit, there is an OK result object
- return all(any(pur.ok() for pur in self._unit_result_mapping[pu])
- for pu in self._protocol_units)
+ return all(any(pur.ok() for pur in self._unit_result_mapping[pu]) for pu in self._protocol_units)
@property
def terminal_protocol_unit_results(self) -> list[ProtocolUnitResult]:
@@ -265,8 +263,7 @@ def terminal_protocol_unit_results(self) -> list[ProtocolUnitResult]:
All ProtocolUnitResults which do not have a ProtocolUnitResult that
follows on (depends) on them.
"""
- return [u for u in self._protocol_unit_results
- if not nx.ancestors(self._result_graph, u)]
+ return [u for u in self._protocol_unit_results if not nx.ancestors(self._result_graph, u)]
class ProtocolDAG(GufeTokenizable, DAGMixin):
@@ -336,24 +333,28 @@ def _defaults(cls):
return {}
def _to_dict(self):
- return {'name': self.name,
- 'protocol_units': self.protocol_units,
- 'transformation_key': self._transformation_key,
- 'extends_key': self._extends_key}
+ return {
+ "name": self.name,
+ "protocol_units": self.protocol_units,
+ "transformation_key": self._transformation_key,
+ "extends_key": self._extends_key,
+ }
@classmethod
def _from_dict(cls, dct: dict):
return cls(**dct)
-def execute_DAG(protocoldag: ProtocolDAG, *,
- shared_basedir: Path,
- scratch_basedir: Path,
- keep_shared: bool = False,
- keep_scratch: bool = False,
- raise_error: bool = True,
- n_retries: int = 0,
- ) -> ProtocolDAGResult:
+def execute_DAG(
+ protocoldag: ProtocolDAG,
+ *,
+ shared_basedir: Path,
+ scratch_basedir: Path,
+ keep_shared: bool = False,
+ keep_scratch: bool = False,
+ raise_error: bool = True,
+ n_retries: int = 0,
+) -> ProtocolDAGResult:
"""
Locally execute a full :class:`ProtocolDAG` in serial and in-process.
@@ -400,21 +401,17 @@ def execute_DAG(protocoldag: ProtocolDAG, *,
attempt = 0
while attempt <= n_retries:
- shared = shared_basedir / f'shared_{str(unit.key)}_attempt_{attempt}'
+ shared = shared_basedir / f"shared_{str(unit.key)}_attempt_{attempt}"
shared_paths.append(shared)
shared.mkdir()
- scratch = scratch_basedir / f'scratch_{str(unit.key)}_attempt_{attempt}'
+ scratch = scratch_basedir / f"scratch_{str(unit.key)}_attempt_{attempt}"
scratch.mkdir()
- context = Context(shared=shared,
- scratch=scratch)
+ context = Context(shared=shared, scratch=scratch)
# execute
- result = unit.execute(
- context=context,
- raise_error=raise_error,
- **inputs)
+ result = unit.execute(context=context, raise_error=raise_error, **inputs)
all_results.append(result)
if not keep_scratch:
@@ -434,16 +431,18 @@ def execute_DAG(protocoldag: ProtocolDAG, *,
shutil.rmtree(shared_path)
return ProtocolDAGResult(
- name=protocoldag.name,
- protocol_units=protocoldag.protocol_units,
- protocol_unit_results=all_results,
- transformation_key=protocoldag.transformation_key,
- extends_key=protocoldag.extends_key)
+ name=protocoldag.name,
+ protocol_units=protocoldag.protocol_units,
+ protocol_unit_results=all_results,
+ transformation_key=protocoldag.transformation_key,
+ extends_key=protocoldag.extends_key,
+ )
def _pu_to_pur(
- inputs: Union[dict[str, Any], list[Any], ProtocolUnit],
- mapping: dict[GufeKey, ProtocolUnitResult]):
+ inputs: Union[dict[str, Any], list[Any], ProtocolUnit],
+ mapping: dict[GufeKey, ProtocolUnitResult],
+):
"""Convert each `ProtocolUnit` found within `inputs` to its corresponding
`ProtocolUnitResult`.
@@ -467,4 +466,3 @@ def _pu_to_pur(
return mapping[inputs.key]
else:
return inputs
-
diff --git a/gufe/protocols/protocolunit.py b/gufe/protocols/protocolunit.py
index 73728bdf..fba47adb 100644
--- a/gufe/protocols/protocolunit.py
+++ b/gufe/protocols/protocolunit.py
@@ -8,20 +8,19 @@
from __future__ import annotations
import abc
-from dataclasses import dataclass
import datetime
import sys
+import tempfile
import traceback
import uuid
+from collections.abc import Iterable
+from copy import copy
+from dataclasses import dataclass
from os import PathLike
from pathlib import Path
-from copy import copy
-from typing import Iterable, Tuple, List, Dict, Any, Optional, Union
-import tempfile
+from typing import Any, Dict, List, Optional, Tuple, Union
-from ..tokenization import (
- GufeTokenizable, GufeKey, TOKENIZABLE_REGISTRY
-)
+from ..tokenization import TOKENIZABLE_REGISTRY, GufeKey, GufeTokenizable
@dataclass
@@ -30,6 +29,7 @@ class Context:
`ProtocolUnit._execute`.
"""
+
scratch: PathLike
shared: PathLike
@@ -55,14 +55,16 @@ class ProtocolUnitResult(GufeTokenizable):
Successful result of a single :class:`ProtocolUnit` execution.
"""
- def __init__(self, *,
- name: Optional[str] = None,
- source_key: GufeKey,
- inputs: Dict[str, Any],
- outputs: Dict[str, Any],
- start_time: Optional[datetime.datetime] = None,
- end_time: Optional[datetime.datetime] = None,
- ):
+ def __init__(
+ self,
+ *,
+ name: str | None = None,
+ source_key: GufeKey,
+ inputs: dict[str, Any],
+ outputs: dict[str, Any],
+ start_time: datetime.datetime | None = None,
+ end_time: datetime.datetime | None = None,
+ ):
"""
Parameters
----------
@@ -102,16 +104,18 @@ def _defaults(cls):
return {}
def _to_dict(self):
- return {'name': self.name,
- '_key': self.key,
- 'source_key': self.source_key,
- 'inputs': self.inputs,
- 'outputs': self.outputs,
- 'start_time': self.start_time,
- 'end_time': self.end_time}
+ return {
+ "name": self.name,
+ "_key": self.key,
+ "source_key": self.source_key,
+ "inputs": self.inputs,
+ "outputs": self.outputs,
+ "start_time": self.start_time,
+ "end_time": self.end_time,
+ }
@classmethod
- def _from_dict(cls, dct: Dict):
+ def _from_dict(cls, dct: dict):
key = dct.pop("_key")
obj = cls(**dct)
obj._set_key(key)
@@ -138,19 +142,19 @@ def dependencies(self) -> list[ProtocolUnitResult]:
"""All results that this result was dependent on"""
if self._dependencies is None:
self._dependencies = _list_dependencies(self._inputs, ProtocolUnitResult)
- return self._dependencies # type: ignore
+ return self._dependencies # type: ignore
@staticmethod
def ok() -> bool:
return True
@property
- def start_time(self) -> Optional[datetime.datetime]:
+ def start_time(self) -> datetime.datetime | None:
"""The time execution of this Unit began"""
return self._start_time
@property
- def end_time(self) -> Optional[datetime.datetime]:
+ def end_time(self) -> datetime.datetime | None:
"""The time at which execution of this Unit finished"""
return self._end_time
@@ -159,18 +163,18 @@ class ProtocolUnitFailure(ProtocolUnitResult):
"""Failed result of a single :class:`ProtocolUnit` execution."""
def __init__(
- self,
- *,
- name=None,
- source_key,
- inputs,
- outputs,
- _key=None,
- exception,
- traceback,
- start_time: Optional[datetime.datetime] = None,
- end_time: Optional[datetime.datetime] = None,
- ):
+ self,
+ *,
+ name=None,
+ source_key,
+ inputs,
+ outputs,
+ _key=None,
+ exception,
+ traceback,
+ start_time: datetime.datetime | None = None,
+ end_time: datetime.datetime | None = None,
+ ):
"""
Parameters
----------
@@ -194,17 +198,22 @@ def __init__(
"""
self._exception = exception
self._traceback = traceback
- super().__init__(name=name, source_key=source_key, inputs=inputs, outputs=outputs,
- start_time=start_time, end_time=end_time)
+ super().__init__(
+ name=name,
+ source_key=source_key,
+ inputs=inputs,
+ outputs=outputs,
+ start_time=start_time,
+ end_time=end_time,
+ )
def _to_dict(self):
dct = super()._to_dict()
- dct.update({'exception': self.exception,
- 'traceback': self.traceback})
+ dct.update({"exception": self.exception, "traceback": self.traceback})
return dct
@property
- def exception(self) -> Tuple[str, Tuple[Any, ...]]:
+ def exception(self) -> tuple[str, tuple[Any, ...]]:
return self._exception
@property
@@ -218,21 +227,17 @@ def ok() -> bool:
class ProtocolUnit(GufeTokenizable):
"""A unit of work within a ProtocolDAG."""
- _dependencies: Optional[list[ProtocolUnit]]
- def __init__(
- self,
- *,
- name: Optional[str] = None,
- **inputs
- ):
+ _dependencies: list[ProtocolUnit] | None
+
+ def __init__(self, *, name: str | None = None, **inputs):
"""Create an instance of a ProtocolUnit.
Parameters
----------
name : str
- Custom name to give this
- **inputs
+ Custom name to give this
+ **inputs
Keyword arguments, which can include other `ProtocolUnit`s on which
this `ProtocolUnit` is dependent. Should be either `gufe` objects
or JSON-serializables.
@@ -257,28 +262,25 @@ def _defaults(cls):
return {}
def _to_dict(self):
- return {'inputs': self.inputs,
- 'name': self.name,
- '_key': self.key}
+ return {"inputs": self.inputs, "name": self.name, "_key": self.key}
@classmethod
- def _from_dict(cls, dct: Dict):
- _key = dct.pop('_key')
+ def _from_dict(cls, dct: dict):
+ _key = dct.pop("_key")
- obj = cls(name=dct['name'],
- **dct['inputs'])
+ obj = cls(name=dct["name"], **dct["inputs"])
obj._set_key(_key)
return obj
@property
- def name(self) -> Optional[str]:
+ def name(self) -> str | None:
"""
Optional name for the `ProtocolUnit`.
"""
return self._name
@property
- def inputs(self) -> Dict[str, Any]:
+ def inputs(self) -> dict[str, Any]:
"""
Inputs to the `ProtocolUnit`.
@@ -291,12 +293,11 @@ def dependencies(self) -> list[ProtocolUnit]:
"""All units that this unit is dependent on (parents)"""
if self._dependencies is None:
self._dependencies = _list_dependencies(self._inputs, ProtocolUnit)
- return self._dependencies # type: ignore
+ return self._dependencies # type: ignore
- def execute(self, *,
- context: Context,
- raise_error: bool = False,
- **inputs) -> Union[ProtocolUnitResult, ProtocolUnitFailure]:
+ def execute(
+ self, *, context: Context, raise_error: bool = False, **inputs
+ ) -> ProtocolUnitResult | ProtocolUnitFailure:
"""Given `ProtocolUnitResult` s from dependencies, execute this `ProtocolUnit`.
Parameters
@@ -313,14 +314,18 @@ def execute(self, *,
objects this unit is dependent on.
"""
- result: Union[ProtocolUnitResult, ProtocolUnitFailure]
+ result: ProtocolUnitResult | ProtocolUnitFailure
start = datetime.datetime.now()
try:
outputs = self._execute(context, **inputs)
result = ProtocolUnitResult(
- name=self.name, source_key=self.key, inputs=inputs, outputs=outputs,
- start_time=start, end_time=datetime.datetime.now(),
+ name=self.name,
+ source_key=self.key,
+ inputs=inputs,
+ outputs=outputs,
+ start_time=start,
+ end_time=datetime.datetime.now(),
)
except KeyboardInterrupt:
@@ -345,7 +350,7 @@ def execute(self, *,
@staticmethod
@abc.abstractmethod
- def _execute(ctx: Context, **inputs) -> Dict[str, Any]:
+ def _execute(ctx: Context, **inputs) -> dict[str, Any]:
"""Method to override in custom `ProtocolUnit` subclasses.
A `Context` is always given as its first argument, which provides execution
@@ -362,15 +367,15 @@ def _execute(ctx: Context, **inputs) -> Dict[str, Any]:
where instantiation with the subclass `MyProtocolUnit` would look like:
- >>> unit = MyProtocolUnit(settings=settings_dict,
+ >>> unit = MyProtocolUnit(settings=settings_dict,
initialization=another_protocolunit,
some_arg=7,
another_arg="five")
- Inside of `_execute` above:
+ Inside of `_execute` above:
- `settings`, and `some_arg`, would have their values set as given
- `initialization` would get the `ProtocolUnitResult` that comes from
- `another_protocolunit`'s own execution.
+ `another_protocolunit`'s own execution.
- `another_arg` would be accessible via `inputs['another_arg']`
This allows protocol developers to define how `ProtocolUnit`s are
diff --git a/gufe/settings/__init__.py b/gufe/settings/__init__.py
index 01cf0e78..be4c3fde 100644
--- a/gufe/settings/__init__.py
+++ b/gufe/settings/__init__.py
@@ -1,10 +1,4 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/gufe
"""General models for defining the parameters that protocols use"""
-from .models import (
- Settings,
- ThermoSettings,
- BaseForceFieldSettings,
- OpenMMSystemGeneratorFFSettings,
- SettingsBaseModel,
-)
+from .models import BaseForceFieldSettings, OpenMMSystemGeneratorFFSettings, Settings, SettingsBaseModel, ThermoSettings
diff --git a/gufe/settings/models.py b/gufe/settings/models.py
index 08d3973c..734fcb56 100644
--- a/gufe/settings/models.py
+++ b/gufe/settings/models.py
@@ -5,21 +5,15 @@
"""
import abc
+import pprint
from typing import Optional, Union
from openff.models.models import DefaultModel
from openff.models.types import FloatQuantity
from openff.units import unit
-import pprint
try:
- from pydantic.v1 import (
- Extra,
- Field,
- PositiveFloat,
- PrivateAttr,
- validator,
- )
+ from pydantic.v1 import Extra, Field, PositiveFloat, PrivateAttr, validator
except ImportError:
from pydantic import (
Extra,
@@ -28,17 +22,20 @@
PrivateAttr,
validator,
)
+
import pydantic
class SettingsBaseModel(DefaultModel):
"""Settings and modifications we want for all settings classes."""
+
_is_frozen: bool = PrivateAttr(default_factory=lambda: False)
class Config:
"""
:noindex:
"""
+
extra = pydantic.Extra.forbid
arbitrary_types_allowed = False
smart_union = True
@@ -56,8 +53,7 @@ def frozen_copy(self):
def freeze_model(model):
submodels = (
- mod for field in model.__fields__
- if isinstance(mod := getattr(model, field), SettingsBaseModel)
+ mod for field in model.__fields__ if isinstance(mod := getattr(model, field), SettingsBaseModel)
)
for mod in submodels:
freeze_model(mod)
@@ -78,8 +74,7 @@ def unfrozen_copy(self):
def unfreeze_model(model):
submodels = (
- mod for field in model.__fields__
- if isinstance(mod := getattr(model, field), SettingsBaseModel)
+ mod for field in model.__fields__ if isinstance(mod := getattr(model, field), SettingsBaseModel)
)
for mod in submodels:
unfreeze_model(mod)
@@ -100,7 +95,8 @@ def __setattr__(self, name, value):
raise AttributeError(
f"Cannot set '{name}': Settings are immutable once attached"
" to a Protocol and cannot be modified. Modify Settings "
- "*before* creating the Protocol.")
+ "*before* creating the Protocol."
+ )
return super().__setattr__(name, value)
@@ -112,23 +108,22 @@ class ThermoSettings(SettingsBaseModel):
possible.
"""
- temperature: FloatQuantity["kelvin"] = Field(
- None, description="Simulation temperature, default units kelvin"
- )
+ temperature: FloatQuantity["kelvin"] = Field(None, description="Simulation temperature, default units kelvin")
pressure: FloatQuantity["standard_atmosphere"] = Field(
None, description="Simulation pressure, default units standard atmosphere (atm)"
)
ph: Union[PositiveFloat, None] = Field(None, description="Simulation pH")
- redox_potential: Optional[float] = Field(
- None, description="Simulation redox potential"
- )
+ redox_potential: Optional[float] = Field(None, description="Simulation redox potential")
class BaseForceFieldSettings(SettingsBaseModel, abc.ABC):
"""Base class for ForceFieldSettings objects"""
+
class Config:
""":noindex:"""
+
pass
+
...
@@ -145,11 +140,13 @@ class OpenMMSystemGeneratorFFSettings(BaseForceFieldSettings):
.. _`OpenMMForceField SystemGenerator documentation`:
https://github.com/openmm/openmmforcefields#automating-force-field-management-with-systemgenerator
"""
+
class Config:
""":noindex:"""
+
pass
-
- constraints: Optional[str] = 'hbonds'
+
+ constraints: Optional[str] = "hbonds"
"""Constraints to be applied to system.
One of 'hbonds', 'allbonds', 'hangles' or None, default 'hbonds'"""
rigid_water: bool = True
@@ -169,39 +166,37 @@ class Config:
small_molecule_forcefield: str = "openff-2.1.1" # other default ideas 'openff-2.0.0', 'gaff-2.11', 'espaloma-0.2.0'
"""Name of the force field to be used for :class:`SmallMoleculeComponent` """
- nonbonded_method = 'PME'
+ nonbonded_method = "PME"
"""
Method for treating nonbonded interactions, currently only PME and
NoCutoff are allowed. Default PME.
"""
- nonbonded_cutoff: FloatQuantity['nanometer'] = 1.0 * unit.nanometer
+ nonbonded_cutoff: FloatQuantity["nanometer"] = 1.0 * unit.nanometer
"""
Cutoff value for short range nonbonded interactions.
Default 1.0 * unit.nanometer.
"""
- @validator('nonbonded_method')
+ @validator("nonbonded_method")
def allowed_nonbonded(cls, v):
- if v.lower() not in ['pme', 'nocutoff']:
- errmsg = (
- "Only PME and NoCutoff are allowed nonbonded_methods")
+ if v.lower() not in ["pme", "nocutoff"]:
+ errmsg = "Only PME and NoCutoff are allowed nonbonded_methods"
raise ValueError(errmsg)
return v
- @validator('nonbonded_cutoff')
+ @validator("nonbonded_cutoff")
def is_positive_distance(cls, v):
# these are time units, not simulation steps
if not v.is_compatible_with(unit.nanometer):
- raise ValueError("nonbonded_cutoff must be in distance units "
- "(i.e. nanometers)")
+ raise ValueError("nonbonded_cutoff must be in distance units " "(i.e. nanometers)")
if v < 0:
errmsg = "nonbonded_cutoff must be a positive value"
raise ValueError(errmsg)
return v
- @validator('constraints')
+ @validator("constraints")
def constraint_check(cls, v):
- allowed = {'hbonds', 'hangles', 'allbonds'}
+ allowed = {"hbonds", "hangles", "allbonds"}
if not (v is None or v.lower() in allowed):
raise ValueError(f"Bad constraints value, use one of {allowed}")
@@ -217,6 +212,7 @@ class Settings(SettingsBaseModel):
Protocols can subclass this to extend this to cater for their additional settings.
"""
+
forcefield_settings: BaseForceFieldSettings
thermo_settings: ThermoSettings
diff --git a/gufe/storage/__init__.py b/gufe/storage/__init__.py
index 27bcd5ae..853c55a2 100644
--- a/gufe/storage/__init__.py
+++ b/gufe/storage/__init__.py
@@ -1 +1 @@
-"""How to store objects across simulation campaigns"""
\ No newline at end of file
+"""How to store objects across simulation campaigns"""
diff --git a/gufe/storage/errors.py b/gufe/storage/errors.py
index ca3b9889..079e452a 100644
--- a/gufe/storage/errors.py
+++ b/gufe/storage/errors.py
@@ -1,8 +1,10 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/gufe
+
class ExternalResourceError(Exception):
"""Base class for errors due to problems with external resources"""
+
# TODO: is it necessary to have a base class here? Would you ever have
# one catch that handles both subclass errors?
diff --git a/gufe/storage/externalresource/base.py b/gufe/storage/externalresource/base.py
index f72a3de6..ae32d75a 100644
--- a/gufe/storage/externalresource/base.py
+++ b/gufe/storage/externalresource/base.py
@@ -1,18 +1,16 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/gufe
import abc
+import dataclasses
+import glob
import hashlib
-import pathlib
-import shutil
import io
import os
-import glob
-from typing import Union, Tuple, ContextManager
-import dataclasses
+import pathlib
+import shutil
+from typing import ContextManager, Tuple, Union
-from ..errors import (
- MissingExternalResourceError, ChangedExternalResourceError
-)
+from ..errors import ChangedExternalResourceError, MissingExternalResourceError
@dataclasses.dataclass
@@ -31,6 +29,7 @@ class _ForceContext:
Filelike objects can often be used with explicit open/close. This
requires the returned byteslike to be consumed as a context manager.
"""
+
def __init__(self, context):
self._context = context
diff --git a/gufe/storage/externalresource/filestorage.py b/gufe/storage/externalresource/filestorage.py
index 5330f88b..e2b5a54b 100644
--- a/gufe/storage/externalresource/filestorage.py
+++ b/gufe/storage/externalresource/filestorage.py
@@ -1,16 +1,13 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/gufe
+import os
import pathlib
import shutil
-import os
-from typing import Union, Tuple, ContextManager
+from typing import ContextManager, Tuple, Union
+from ..errors import ChangedExternalResourceError, MissingExternalResourceError
from .base import ExternalStorage
-from ..errors import (
- MissingExternalResourceError, ChangedExternalResourceError
-)
-
# TODO: this should use pydantic to check init inputs
class FileStorage(ExternalStorage):
@@ -21,10 +18,7 @@ def _exists(self, location):
return self._as_path(location).exists()
def __eq__(self, other):
- return (
- isinstance(other, FileStorage)
- and self.root_dir == other.root_dir
- )
+ return isinstance(other, FileStorage) and self.root_dir == other.root_dir
def _store_bytes(self, location, byte_data):
path = self._as_path(location)
@@ -32,7 +26,7 @@ def _store_bytes(self, location, byte_data):
filename = path.name
# TODO: add some stuff here to catch permissions-based errors
directory.mkdir(parents=True, exist_ok=True)
- with open(path, mode='wb') as f:
+ with open(path, mode="wb") as f:
f.write(byte_data)
def _store_path(self, location, path):
@@ -55,9 +49,7 @@ def _delete(self, location):
if self.exists(location):
path.unlink()
else:
- raise MissingExternalResourceError(
- f"Unable to delete '{str(path)}': File does not exist"
- )
+ raise MissingExternalResourceError(f"Unable to delete '{str(path)}': File does not exist")
def _as_path(self, location):
return self.root_dir / pathlib.Path(location)
@@ -73,6 +65,6 @@ def _get_filename(self, location):
def _load_stream(self, location):
try:
- return open(self._as_path(location), 'rb')
+ return open(self._as_path(location), "rb")
except OSError as e:
raise MissingExternalResourceError(str(e))
diff --git a/gufe/storage/externalresource/memorystorage.py b/gufe/storage/externalresource/memorystorage.py
index 40c05087..e8ee07d3 100644
--- a/gufe/storage/externalresource/memorystorage.py
+++ b/gufe/storage/externalresource/memorystorage.py
@@ -1,17 +1,15 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/gufe
import io
-from typing import Union, Tuple, ContextManager
+from typing import ContextManager, Tuple, Union
+from ..errors import ChangedExternalResourceError, MissingExternalResourceError
from .base import ExternalStorage
-from ..errors import (
- MissingExternalResourceError, ChangedExternalResourceError
-)
-
class MemoryStorage(ExternalStorage):
"""Not for production use, but potentially useful in testing"""
+
def __init__(self):
self._data = {}
@@ -22,9 +20,7 @@ def _delete(self, location):
try:
del self._data[location]
except KeyError:
- raise MissingExternalResourceError(
- f"Unable to delete '{location}': key does not exist"
- )
+ raise MissingExternalResourceError(f"Unable to delete '{location}': key does not exist")
def __eq__(self, other):
return self is other
@@ -34,7 +30,7 @@ def _store_bytes(self, location, byte_data):
return location, self.get_metadata(location)
def _store_path(self, location, path):
- with open(path, 'rb') as f:
+ with open(path, "rb") as f:
byte_data = f.read()
return self._store_bytes(location, byte_data)
diff --git a/gufe/tests/conftest.py b/gufe/tests/conftest.py
index 3080f469..63dd30a0 100644
--- a/gufe/tests/conftest.py
+++ b/gufe/tests/conftest.py
@@ -1,15 +1,17 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/gufe
+import functools
import importlib.resources
+import io
import urllib.request
from urllib.error import URLError
-import io
-import functools
+
import pytest
+from openff.units import unit
from rdkit import Chem
from rdkit.Chem import AllChem
-from openff.units import unit
+
import gufe
from gufe.tests.test_protocol import DummyProtocol
@@ -22,7 +24,7 @@
class URLFileLike:
- def __init__(self, url, encoding='utf-8'):
+ def __init__(self, url, encoding="utf-8"):
self.url = url
self.encoding = encoding
self.data = None
@@ -39,21 +41,21 @@ def __call__(self):
def get_test_filename(filename):
- with importlib.resources.path('gufe.tests.data', filename) as file:
+ with importlib.resources.path("gufe.tests.data", filename) as file:
return str(file)
_benchmark_pdb_names = [
- "cmet_protein",
- "hif2a_protein",
- "mcl1_protein",
- "p38_protein",
- "ptp1b_protein",
- "syk_protein",
- "thrombin_protein",
- "tnsk2_protein",
- "tyk2_protein",
- ]
+ "cmet_protein",
+ "hif2a_protein",
+ "mcl1_protein",
+ "p38_protein",
+ "ptp1b_protein",
+ "syk_protein",
+ "thrombin_protein",
+ "tnsk2_protein",
+ "tyk2_protein",
+]
_pl_benchmark_url_pattern = (
@@ -62,14 +64,10 @@ def get_test_filename(filename):
PDB_BENCHMARK_LOADERS = {
- name: URLFileLike(url=_pl_benchmark_url_pattern.format(name=name))
- for name in _benchmark_pdb_names
+ name: URLFileLike(url=_pl_benchmark_url_pattern.format(name=name)) for name in _benchmark_pdb_names
}
-PDB_FILE_LOADERS = {
- name: lambda: get_test_filename(name)
- for name in ["181l.pdb"]
-}
+PDB_FILE_LOADERS = {name: lambda: get_test_filename(name) for name in ["181l.pdb"]}
ALL_PDB_LOADERS = dict(**PDB_BENCHMARK_LOADERS, **PDB_FILE_LOADERS)
@@ -82,78 +80,72 @@ def ethane_sdf():
@pytest.fixture
def toluene_mol2_path():
- with importlib.resources.path('gufe.tests.data', 'toluene.mol2') as f:
+ with importlib.resources.path("gufe.tests.data", "toluene.mol2") as f:
yield str(f)
@pytest.fixture
def multi_molecule_sdf():
- fn = 'multi_molecule.sdf'
- with importlib.resources.path('gufe.tests.data', fn) as f:
+ fn = "multi_molecule.sdf"
+ with importlib.resources.path("gufe.tests.data", fn) as f:
yield str(f)
@pytest.fixture
def PDB_181L_path():
- with importlib.resources.path('gufe.tests.data', '181l.pdb') as f:
+ with importlib.resources.path("gufe.tests.data", "181l.pdb") as f:
yield str(f)
@pytest.fixture
def PDB_181L_OpenMMClean_path():
- with importlib.resources.path('gufe.tests.data',
- '181l_openmmClean.pdb') as f:
+ with importlib.resources.path("gufe.tests.data", "181l_openmmClean.pdb") as f:
yield str(f)
@pytest.fixture
def offxml_settings_path():
- with importlib.resources.path('gufe.tests.data', 'offxml_settings.json') as f:
+ with importlib.resources.path("gufe.tests.data", "offxml_settings.json") as f:
yield str(f)
@pytest.fixture
def all_settings_path():
- with importlib.resources.path('gufe.tests.data', 'all_settings.json') as f:
+ with importlib.resources.path("gufe.tests.data", "all_settings.json") as f:
yield str(f)
@pytest.fixture
def PDB_thrombin_path():
- with importlib.resources.path('gufe.tests.data',
- 'thrombin_protein.pdb') as f:
+ with importlib.resources.path("gufe.tests.data", "thrombin_protein.pdb") as f:
yield str(f)
@pytest.fixture
def PDBx_181L_path():
- with importlib.resources.path('gufe.tests.data',
- '181l.cif') as f:
+ with importlib.resources.path("gufe.tests.data", "181l.cif") as f:
yield str(f)
@pytest.fixture
def PDBx_181L_openMMClean_path():
- with importlib.resources.path('gufe.tests.data',
- '181l_openmmClean.cif') as f:
+ with importlib.resources.path("gufe.tests.data", "181l_openmmClean.cif") as f:
yield str(f)
-@pytest.fixture(scope='session')
+@pytest.fixture(scope="session")
def benzene_modifications():
- with importlib.resources.path('gufe.tests.data',
- 'benzene_modifications.sdf') as f:
+ with importlib.resources.path("gufe.tests.data", "benzene_modifications.sdf") as f:
supp = Chem.SDMolSupplier(str(f), removeHs=False)
mols = list(supp)
- return {m.GetProp('_Name'): m for m in mols}
+ return {m.GetProp("_Name"): m for m in mols}
-@pytest.fixture(scope='session')
+@pytest.fixture(scope="session")
def benzene_transforms(benzene_modifications):
- return {k: gufe.SmallMoleculeComponent(v)
- for k, v in benzene_modifications.items()}
+ return {k: gufe.SmallMoleculeComponent(v) for k, v in benzene_modifications.items()}
@pytest.fixture
@@ -168,7 +160,7 @@ def toluene(benzene_modifications):
@pytest.fixture
def phenol(benzene_modifications):
- return gufe.SmallMoleculeComponent(benzene_modifications['phenol'])
+ return gufe.SmallMoleculeComponent(benzene_modifications["phenol"])
@pytest.fixture
@@ -253,9 +245,10 @@ def absolute_transformation(solvated_ligand, solvated_complex):
def complex_equilibrium(solvated_complex):
return gufe.NonTransformation(
solvated_complex,
- protocol=DummyProtocol(settings=DummyProtocol.default_settings())
+ protocol=DummyProtocol(settings=DummyProtocol.default_settings()),
)
+
@pytest.fixture
def benzene_variants_star_map_transformations(
benzene,
@@ -269,26 +262,20 @@ def benzene_variants_star_map_transformations(
solv_comp,
):
- variants = [toluene, phenol, benzonitrile, anisole, benzaldehyde,
- styrene]
+ variants = [toluene, phenol, benzonitrile, anisole, benzaldehyde, styrene]
# define the solvent chemical systems and transformations between
# benzene and the others
solvated_ligands = {}
solvated_ligand_transformations = {}
- solvated_ligands["benzene"] = gufe.ChemicalSystem(
- {"solvent": solv_comp, "ligand": benzene}, name="benzene-solvent"
- )
+ solvated_ligands["benzene"] = gufe.ChemicalSystem({"solvent": solv_comp, "ligand": benzene}, name="benzene-solvent")
for ligand in variants:
solvated_ligands[ligand.name] = gufe.ChemicalSystem(
- {"solvent": solv_comp, "ligand": ligand},
- name=f"{ligand.name}-solvnet"
+ {"solvent": solv_comp, "ligand": ligand}, name=f"{ligand.name}-solvnet"
)
- solvated_ligand_transformations[
- ("benzene", ligand.name)
- ] = gufe.Transformation(
+ solvated_ligand_transformations[("benzene", ligand.name)] = gufe.Transformation(
solvated_ligands["benzene"],
solvated_ligands[ligand.name],
protocol=DummyProtocol(settings=DummyProtocol.default_settings()),
@@ -310,9 +297,7 @@ def benzene_variants_star_map_transformations(
{"protein": prot_comp, "solvent": solv_comp, "ligand": ligand},
name=f"{ligand.name}-complex",
)
- solvated_complex_transformations[
- ("benzene", ligand.name)
- ] = gufe.Transformation(
+ solvated_complex_transformations[("benzene", ligand.name)] = gufe.Transformation(
solvated_complexes["benzene"],
solvated_complexes[ligand.name],
protocol=DummyProtocol(settings=DummyProtocol.default_settings()),
@@ -325,7 +310,8 @@ def benzene_variants_star_map_transformations(
@pytest.fixture
def benzene_variants_star_map(benzene_variants_star_map_transformations):
solvated_ligand_transformations, solvated_complex_transformations = benzene_variants_star_map_transformations
- return gufe.AlchemicalNetwork(solvated_ligand_transformations+solvated_complex_transformations)
+ return gufe.AlchemicalNetwork(solvated_ligand_transformations + solvated_complex_transformations)
+
@pytest.fixture
def benzene_variants_ligand_star_map(benzene_variants_star_map_transformations):
diff --git a/gufe/tests/data/__init__.py b/gufe/tests/data/__init__.py
index 3bf5fbe3..d9960422 100644
--- a/gufe/tests/data/__init__.py
+++ b/gufe/tests/data/__init__.py
@@ -1,2 +1,2 @@
# This code is part of OpenFE and is licensed under the MIT license.
-# For details, see https://github.com/OpenFreeEnergy/gufe
\ No newline at end of file
+# For details, see https://github.com/OpenFreeEnergy/gufe
diff --git a/gufe/tests/data/all_settings.json b/gufe/tests/data/all_settings.json
index 2218fba7..d50c1a79 100644
--- a/gufe/tests/data/all_settings.json
+++ b/gufe/tests/data/all_settings.json
@@ -22,4 +22,4 @@
"redox_potential": null
},
"protocol_settings": null
-}
\ No newline at end of file
+}
diff --git a/gufe/tests/data/ligand_network.graphml b/gufe/tests/data/ligand_network.graphml
index 31fe9e56..d331b2f6 100644
--- a/gufe/tests/data/ligand_network.graphml
+++ b/gufe/tests/data/ligand_network.graphml
@@ -4,13 +4,13 @@
- {":version:": 1, "__module__": "gufe.components.smallmoleculecomponent", "__qualname__": "SmallMoleculeComponent", "atoms": [[6, 0, 0, false, 0, 0, {}], [6, 0, 0, false, 0, 0, {}]], "bonds": [[0, 1, 1, 0, {}]], "conformer": ["\u0093NUMPY\u0001\u0000v\u0000{'descr': '<f8', 'fortran_order': False, 'shape': (2, 3), } \n\u0000\u0000\u0000\u0000\u0000\u0000\u00e8\u00bf\u0000\u0000\u0000\u0000\u0000\u0000\u0090<\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u00e8?\u0000\u0000\u0000\u0000\u0000\u0000\u0090\u00bc\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000", {}], "molprops": {"ofe-name": ""}}
+ {":version:": 1, "__module__": "gufe.components.smallmoleculecomponent", "__qualname__": "SmallMoleculeComponent", "atoms": [[6, 0, 0, false, 0, 0, {}, 4], [6, 0, 0, false, 0, 0, {}, 4]], "bonds": [[0, 1, 1, 0, {}]], "conformer": ["\u0093NUMPY\u0001\u0000v\u0000{'descr': '<f8', 'fortran_order': False, 'shape': (2, 3), } \n\u0000\u0000\u0000\u0000\u0000\u0000\u00e8\u00bf\u0000\u0000\u0000\u0000\u0000\u0000\u0090<\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u00e8?\u0000\u0000\u0000\u0000\u0000\u0000\u0090\u00bc\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000", {}], "molprops": {"ofe-name": ""}}
- {":version:": 1, "__module__": "gufe.components.smallmoleculecomponent", "__qualname__": "SmallMoleculeComponent", "atoms": [[6, 0, 0, false, 0, 0, {}], [6, 0, 0, false, 0, 0, {}], [8, 0, 0, false, 0, 0, {}]], "bonds": [[0, 1, 1, 0, {}], [1, 2, 1, 0, {}]], "conformer": ["\u0093NUMPY\u0001\u0000v\u0000{'descr': '<f8', 'fortran_order': False, 'shape': (3, 3), } \n\u00809B.\u00dc\u00c8\u00f4\u00bf\u00f5\u00ff\u00ff\u00ff\u00ff\u00ff\u00cf\u00bf\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0001\u0000\u0000\u0000\u0000\u0000\u00e0?\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u00809B.\u00dc\u00c8\u00f4?\u0006\u0000\u0000\u0000\u0000\u0000\u00d0\u00bf\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000", {}], "molprops": {"ofe-name": ""}}
+ {":version:": 1, "__module__": "gufe.components.smallmoleculecomponent", "__qualname__": "SmallMoleculeComponent", "atoms": [[6, 0, 0, false, 0, 0, {}, 4], [6, 0, 0, false, 0, 0, {}, 4], [8, 0, 0, false, 0, 0, {}, 4]], "bonds": [[0, 1, 1, 0, {}], [1, 2, 1, 0, {}]], "conformer": ["\u0093NUMPY\u0001\u0000v\u0000{'descr': '<f8', 'fortran_order': False, 'shape': (3, 3), } \n\u00809B.\u00dc\u00c8\u00f4\u00bf\u00f5\u00ff\u00ff\u00ff\u00ff\u00ff\u00cf\u00bf\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0001\u0000\u0000\u0000\u0000\u0000\u00e0?\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u00809B.\u00dc\u00c8\u00f4?\u0006\u0000\u0000\u0000\u0000\u0000\u00d0\u00bf\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000", {}], "molprops": {"ofe-name": ""}}
- {":version:": 1, "__module__": "gufe.components.smallmoleculecomponent", "__qualname__": "SmallMoleculeComponent", "atoms": [[6, 0, 0, false, 0, 0, {}], [8, 0, 0, false, 0, 0, {}]], "bonds": [[0, 1, 1, 0, {}]], "conformer": ["\u0093NUMPY\u0001\u0000v\u0000{'descr': '<f8', 'fortran_order': False, 'shape': (2, 3), } \n\u0000\u0000\u0000\u0000\u0000\u0000\u00e8\u00bf\u0000\u0000\u0000\u0000\u0000\u0000\u0090<\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u00e8?\u0000\u0000\u0000\u0000\u0000\u0000\u0090\u00bc\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000", {}], "molprops": {"ofe-name": ""}}
+ {":version:": 1, "__module__": "gufe.components.smallmoleculecomponent", "__qualname__": "SmallMoleculeComponent", "atoms": [[6, 0, 0, false, 0, 0, {}, 4], [8, 0, 0, false, 0, 0, {}, 4]], "bonds": [[0, 1, 1, 0, {}]], "conformer": ["\u0093NUMPY\u0001\u0000v\u0000{'descr': '<f8', 'fortran_order': False, 'shape': (2, 3), } \n\u0000\u0000\u0000\u0000\u0000\u0000\u00e8\u00bf\u0000\u0000\u0000\u0000\u0000\u0000\u0090<\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u00e8?\u0000\u0000\u0000\u0000\u0000\u0000\u0090\u00bc\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000", {}], "molprops": {"ofe-name": ""}}
[[0, 0]]
diff --git a/gufe/tests/dev/serialization_test_templates.py b/gufe/tests/dev/serialization_test_templates.py
index 95d45a85..6d1f7e56 100644
--- a/gufe/tests/dev/serialization_test_templates.py
+++ b/gufe/tests/dev/serialization_test_templates.py
@@ -4,7 +4,8 @@
from rdkit import Chem
from rdkit.Chem import AllChem
-from gufe import SmallMoleculeComponent, LigandNetwork, LigandAtomMapping
+
+from gufe import LigandAtomMapping, LigandNetwork, SmallMoleculeComponent
def mol_from_smiles(smiles: str) -> Chem.Mol:
diff --git a/gufe/tests/storage/test_externalresource.py b/gufe/tests/storage/test_externalresource.py
index 43736991..adfe9170 100644
--- a/gufe/tests/storage/test_externalresource.py
+++ b/gufe/tests/storage/test_externalresource.py
@@ -1,13 +1,12 @@
-import pytest
-import pathlib
import hashlib
import os
+import pathlib
from unittest import mock
+import pytest
+
+from gufe.storage.errors import ChangedExternalResourceError, MissingExternalResourceError
from gufe.storage.externalresource import FileStorage, MemoryStorage
-from gufe.storage.errors import (
- MissingExternalResourceError, ChangedExternalResourceError
-)
# NOTE: Tests for the abstract base are just part of the tests of its
# subclasses
@@ -21,43 +20,46 @@ def file_storage(tmp_path):
* foo.txt : contents "bar"
* with/directory.txt : contents "in a directory"
"""
- with open(tmp_path / 'foo.txt', 'wb') as foo:
- foo.write("bar".encode("utf-8"))
+ with open(tmp_path / "foo.txt", "wb") as foo:
+ foo.write(b"bar")
- inner_dir = tmp_path / 'with'
+ inner_dir = tmp_path / "with"
inner_dir.mkdir()
- with open(inner_dir / 'directory.txt', 'wb') as with_dir:
- with_dir.write("in a directory".encode("utf-8"))
+ with open(inner_dir / "directory.txt", "wb") as with_dir:
+ with_dir.write(b"in a directory")
return FileStorage(tmp_path)
class TestFileStorage:
- @pytest.mark.parametrize('filename, expected', [
- ('foo.txt', True),
- ('notexisting.txt', False),
- ('with/directory.txt', True),
- ])
+ @pytest.mark.parametrize(
+ "filename, expected",
+ [
+ ("foo.txt", True),
+ ("notexisting.txt", False),
+ ("with/directory.txt", True),
+ ],
+ )
def test_exists(self, filename, expected, file_storage):
assert file_storage.exists(filename) == expected
def test_store_bytes(self, file_storage):
fileloc = file_storage.root_dir / "bar.txt"
assert not fileloc.exists()
- as_bytes = "This is bar".encode('utf-8')
+ as_bytes = b"This is bar"
file_storage.store_bytes("bar.txt", as_bytes)
assert fileloc.exists()
- with open(fileloc, 'rb') as f:
+ with open(fileloc, "rb") as f:
assert as_bytes == f.read()
- @pytest.mark.parametrize('nested', [True, False])
+ @pytest.mark.parametrize("nested", [True, False])
def test_store_path(self, file_storage, nested):
orig_file = file_storage.root_dir / ".hidden" / "bar.txt"
orig_file.parent.mkdir()
- as_bytes = "This is bar".encode('utf-8')
- with open(orig_file, 'wb') as f:
+ as_bytes = b"This is bar"
+ with open(orig_file, "wb") as f:
f.write(as_bytes)
nested_dir = "nested" if nested else ""
@@ -67,7 +69,7 @@ def test_store_path(self, file_storage, nested):
file_storage.store_path(fileloc, orig_file)
assert fileloc.exists()
- with open(fileloc, 'rb') as f:
+ with open(fileloc, "rb") as f:
assert as_bytes == f.read()
def test_eq(self, tmp_path):
@@ -82,25 +84,28 @@ def test_delete(self, file_storage):
file_storage.delete("foo.txt")
assert not path.exists()
- @pytest.mark.parametrize('prefix,expected', [
- ("", {'foo.txt', 'foo_dir/a.txt', 'foo_dir/b.txt'}),
- ("foo", {'foo.txt', 'foo_dir/a.txt', 'foo_dir/b.txt'}),
- ("foo_dir/", {'foo_dir/a.txt', 'foo_dir/b.txt'}),
- ("foo_dir/a", {'foo_dir/a.txt'}),
- ("foo_dir/a.txt", {'foo_dir/a.txt'}),
- ("baz", set()),
- ])
+ @pytest.mark.parametrize(
+ "prefix,expected",
+ [
+ ("", {"foo.txt", "foo_dir/a.txt", "foo_dir/b.txt"}),
+ ("foo", {"foo.txt", "foo_dir/a.txt", "foo_dir/b.txt"}),
+ ("foo_dir/", {"foo_dir/a.txt", "foo_dir/b.txt"}),
+ ("foo_dir/a", {"foo_dir/a.txt"}),
+ ("foo_dir/a.txt", {"foo_dir/a.txt"}),
+ ("baz", set()),
+ ],
+ )
def test_iter_contents(self, tmp_path, prefix, expected):
files = [
- 'foo.txt',
- 'foo_dir/a.txt',
- 'foo_dir/b.txt',
+ "foo.txt",
+ "foo_dir/a.txt",
+ "foo_dir/b.txt",
]
for file in files:
path = tmp_path / file
path.parent.mkdir(parents=True, exist_ok=True)
assert not path.exists()
- with open(path, 'wb') as f:
+ with open(path, "wb") as f:
f.write(b"")
storage = FileStorage(tmp_path)
@@ -108,8 +113,7 @@ def test_iter_contents(self, tmp_path, prefix, expected):
assert set(storage.iter_contents(prefix)) == expected
def test_delete_error_not_existing(self, file_storage):
- with pytest.raises(MissingExternalResourceError,
- match="does not exist"):
+ with pytest.raises(MissingExternalResourceError, match="does not exist"):
file_storage.delete("baz.txt")
def test_get_filename(self, file_storage):
@@ -119,7 +123,7 @@ def test_get_filename(self, file_storage):
def test_load_stream(self, file_storage):
with file_storage.load_stream("foo.txt") as f:
- results = f.read().decode('utf-8')
+ results = f.read().decode("utf-8")
assert results == "bar"
@@ -130,25 +134,24 @@ def test_load_stream_error_missing(self, file_storage):
class TestMemoryStorage:
def setup_method(self):
- self.contents = {'path/to/foo.txt': 'bar'.encode('utf-8')}
+ self.contents = {"path/to/foo.txt": b"bar"}
self.storage = MemoryStorage()
self.storage._data = dict(self.contents)
- @pytest.mark.parametrize('expected', [True, False])
+ @pytest.mark.parametrize("expected", [True, False])
def test_exists(self, expected):
path = "path/to/foo.txt" if expected else "path/to/bar.txt"
assert self.storage.exists(path) is expected
def test_delete(self):
# checks internal state
- assert 'path/to/foo.txt' in self.storage._data
- self.storage.delete('path/to/foo.txt')
- assert 'path/to/foo.txt' not in self.storage._data
+ assert "path/to/foo.txt" in self.storage._data
+ self.storage.delete("path/to/foo.txt")
+ assert "path/to/foo.txt" not in self.storage._data
def test_delete_error_not_existing(self):
- with pytest.raises(MissingExternalResourceError,
- match="Unable to delete"):
- self.storage.delete('does/not/exist.txt')
+ with pytest.raises(MissingExternalResourceError, match="Unable to delete"):
+ self.storage.delete("does/not/exist.txt")
def test_store_bytes(self):
storage = MemoryStorage()
@@ -162,7 +165,7 @@ def test_store_path(self, tmp_path):
for label, data in self.contents.items():
path = tmp_path / label
path.parent.mkdir(parents=True, exist_ok=True)
- with open(path, mode='wb') as f:
+ with open(path, mode="wb") as f:
f.write(data)
storage.store_path(label, path)
@@ -174,14 +177,17 @@ def test_eq(self):
assert reference == reference
assert reference != MemoryStorage()
- @pytest.mark.parametrize('prefix,expected', [
- ("", {'foo.txt', 'foo_dir/a.txt', 'foo_dir/b.txt'}),
- ("foo", {'foo.txt', 'foo_dir/a.txt', 'foo_dir/b.txt'}),
- ("foo_dir/", {'foo_dir/a.txt', 'foo_dir/b.txt'}),
- ("foo_dir/a", {'foo_dir/a.txt'}),
- ("foo_dir/a.txt", {'foo_dir/a.txt'}),
- ("baz", set()),
- ])
+ @pytest.mark.parametrize(
+ "prefix,expected",
+ [
+ ("", {"foo.txt", "foo_dir/a.txt", "foo_dir/b.txt"}),
+ ("foo", {"foo.txt", "foo_dir/a.txt", "foo_dir/b.txt"}),
+ ("foo_dir/", {"foo_dir/a.txt", "foo_dir/b.txt"}),
+ ("foo_dir/a", {"foo_dir/a.txt"}),
+ ("foo_dir/a.txt", {"foo_dir/a.txt"}),
+ ("baz", set()),
+ ],
+ )
def test_iter_contents(self, prefix, expected):
storage = MemoryStorage()
storage._data = {
@@ -198,6 +204,6 @@ def test_get_filename(self):
def test_load_stream(self):
path = "path/to/foo.txt"
with self.storage.load_stream(path) as f:
- results = f.read().decode('utf-8')
+ results = f.read().decode("utf-8")
assert results == "bar"
diff --git a/gufe/tests/test_alchemicalnetwork.py b/gufe/tests/test_alchemicalnetwork.py
index 8231c51d..d65995e0 100644
--- a/gufe/tests/test_alchemicalnetwork.py
+++ b/gufe/tests/test_alchemicalnetwork.py
@@ -1,8 +1,8 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
-import pytest
import networkx as nx
+import pytest
from gufe import AlchemicalNetwork, ChemicalSystem, Transformation
@@ -56,16 +56,16 @@ def test_connected_subgraphs_multiple_subgraphs(self, benzene_variants_star_map)
subgraphs = [subgraph for subgraph in alnet.connected_subgraphs()]
- assert set([len(subgraph.nodes) for subgraph in subgraphs]) == {6,7,1}
+ assert {len(subgraph.nodes) for subgraph in subgraphs} == {6, 7, 1}
# which graph has the removed node is not deterministic, so we just
# check that one graph is all-solvent and the other is all-protein
for subgraph in subgraphs:
components = [frozenset(n.components.keys()) for n in subgraph.nodes]
- if {'solvent','protein','ligand'} in components:
- assert set(components) == {frozenset({'solvent','protein','ligand'})}
+ if {"solvent", "protein", "ligand"} in components:
+ assert set(components) == {frozenset({"solvent", "protein", "ligand"})}
else:
- assert set(components) == {frozenset({'solvent','ligand'})}
+ assert set(components) == {frozenset({"solvent", "ligand"})}
def test_connected_subgraphs_one_subgraph(self, benzene_variants_ligand_star_map):
"""Return the same network if it only contains one connected component."""
diff --git a/gufe/tests/test_chemicalsystem.py b/gufe/tests/test_chemicalsystem.py
index 661ab19c..414cd96b 100644
--- a/gufe/tests/test_chemicalsystem.py
+++ b/gufe/tests/test_chemicalsystem.py
@@ -1,62 +1,56 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
-import pytest
import numpy as np
+import pytest
from gufe import ChemicalSystem
from .test_tokenization import GufeTokenizableTestsMixin
+
def test_ligand_construction(solv_comp, toluene_ligand_comp):
# sanity checks on construction
state = ChemicalSystem(
- {'solvent': solv_comp,
- 'ligand': toluene_ligand_comp},
+ {"solvent": solv_comp, "ligand": toluene_ligand_comp},
)
assert len(state.components) == 2
assert len(state) == 2
- assert list(state) == ['solvent', 'ligand']
+ assert list(state) == ["solvent", "ligand"]
- assert state.components['solvent'] == solv_comp
- assert state.components['ligand'] == toluene_ligand_comp
- assert state['solvent'] == solv_comp
- assert state['ligand'] == toluene_ligand_comp
+ assert state.components["solvent"] == solv_comp
+ assert state.components["ligand"] == toluene_ligand_comp
+ assert state["solvent"] == solv_comp
+ assert state["ligand"] == toluene_ligand_comp
def test_complex_construction(prot_comp, solv_comp, toluene_ligand_comp):
# sanity checks on construction
state = ChemicalSystem(
- {'protein': prot_comp,
- 'solvent': solv_comp,
- 'ligand': toluene_ligand_comp},
+ {"protein": prot_comp, "solvent": solv_comp, "ligand": toluene_ligand_comp},
)
assert len(state.components) == 3
assert len(state) == 3
- assert list(state) == ['protein', 'solvent', 'ligand']
+ assert list(state) == ["protein", "solvent", "ligand"]
- assert state.components['protein'] == prot_comp
- assert state.components['solvent'] == solv_comp
- assert state.components['ligand'] == toluene_ligand_comp
- assert state['protein'] == prot_comp
- assert state['solvent'] == solv_comp
- assert state['ligand'] == toluene_ligand_comp
+ assert state.components["protein"] == prot_comp
+ assert state.components["solvent"] == solv_comp
+ assert state.components["ligand"] == toluene_ligand_comp
+ assert state["protein"] == prot_comp
+ assert state["solvent"] == solv_comp
+ assert state["ligand"] == toluene_ligand_comp
def test_hash_and_eq(prot_comp, solv_comp, toluene_ligand_comp):
- c1 = ChemicalSystem({'protein': prot_comp,
- 'solvent': solv_comp,
- 'ligand': toluene_ligand_comp})
+ c1 = ChemicalSystem({"protein": prot_comp, "solvent": solv_comp, "ligand": toluene_ligand_comp})
- c2 = ChemicalSystem({'solvent': solv_comp,
- 'ligand': toluene_ligand_comp,
- 'protein': prot_comp})
+ c2 = ChemicalSystem({"solvent": solv_comp, "ligand": toluene_ligand_comp, "protein": prot_comp})
assert c1 == c2
assert hash(c1) == hash(c2)
@@ -68,8 +62,7 @@ def test_chemical_system_neq_1(solvated_complex, prot_comp):
assert hash(solvated_complex) != hash(prot_comp)
-def test_chemical_system_neq_2(solvated_complex, prot_comp, solv_comp,
- toluene_ligand_comp):
+def test_chemical_system_neq_2(solvated_complex, prot_comp, solv_comp, toluene_ligand_comp):
# names are different
complex2 = ChemicalSystem(
{"protein": prot_comp, "solvent": solv_comp, "ligand": toluene_ligand_comp},
@@ -86,13 +79,10 @@ def test_chemical_system_neq_4(solvated_complex, solvated_ligand):
assert hash(solvated_complex) != hash(solvated_ligand)
-def test_chemical_system_neq_5(solvated_complex, prot_comp, solv_comp,
- phenol_ligand_comp):
+def test_chemical_system_neq_5(solvated_complex, prot_comp, solv_comp, phenol_ligand_comp):
# same component keys, but different components
complex2 = ChemicalSystem(
- {'protein': prot_comp,
- 'solvent': solv_comp,
- 'ligand': phenol_ligand_comp},
+ {"protein": prot_comp, "solvent": solv_comp, "ligand": phenol_ligand_comp},
)
assert solvated_complex != complex2
assert hash(solvated_complex) != hash(complex2)
@@ -123,6 +113,5 @@ class TestChemicalSystem(GufeTokenizableTestsMixin):
@pytest.fixture
def instance(self, solv_comp, toluene_ligand_comp):
return ChemicalSystem(
- {'solvent': solv_comp,
- 'ligand': toluene_ligand_comp},
- )
+ {"solvent": solv_comp, "ligand": toluene_ligand_comp},
+ )
diff --git a/gufe/tests/test_custom_json.py b/gufe/tests/test_custom_json.py
index 9d4eebe1..dbc61ba7 100644
--- a/gufe/tests/test_custom_json.py
+++ b/gufe/tests/test_custom_json.py
@@ -5,17 +5,19 @@
import json
import pathlib
+from uuid import uuid4
import numpy as np
import openff.units
-from openff.units import unit
import pytest
from numpy import testing as npt
-from uuid import uuid4
+from openff.units import unit
+
+from gufe import tokenization
from gufe.custom_codecs import (
BYTES_CODEC,
- NUMPY_CODEC,
NPY_DTYPE_CODEC,
+ NUMPY_CODEC,
OPENFF_QUANTITY_CODEC,
OPENFF_UNIT_CODEC,
PATH_CODEC,
@@ -23,7 +25,6 @@
UUID_CODEC,
)
from gufe.custom_json import JSONSerializerDeserializer, custom_json_factory
-from gufe import tokenization
from gufe.settings import models
@@ -48,14 +49,14 @@ def test_add_existing_codec(self):
assert len(serialization.codecs) == 1
-@pytest.mark.parametrize('obj', [
- np.array([[1.0, 0.0], [2.0, 3.2]]),
- np.float32(1.1)
-])
-@pytest.mark.parametrize('codecs', [
- [BYTES_CODEC, NUMPY_CODEC, NPY_DTYPE_CODEC],
- [NPY_DTYPE_CODEC, BYTES_CODEC, NUMPY_CODEC],
-])
+@pytest.mark.parametrize("obj", [np.array([[1.0, 0.0], [2.0, 3.2]]), np.float32(1.1)])
+@pytest.mark.parametrize(
+ "codecs",
+ [
+ [BYTES_CODEC, NUMPY_CODEC, NPY_DTYPE_CODEC],
+ [NPY_DTYPE_CODEC, BYTES_CODEC, NUMPY_CODEC],
+ ],
+)
def test_numpy_codec_order_roundtrip(obj, codecs):
serialization = JSONSerializerDeserializer(codecs)
serialized = serialization.serializer(obj)
@@ -76,15 +77,15 @@ class CustomJSONCodingTest:
"""
def test_default(self):
- for (obj, dct) in zip(self.objs, self.dcts):
+ for obj, dct in zip(self.objs, self.dcts):
assert self.codec.default(obj) == dct
def test_object_hook(self):
- for (obj, dct) in zip(self.objs, self.dcts):
+ for obj, dct in zip(self.objs, self.dcts):
assert self.codec.object_hook(dct) == obj
def _test_round_trip(self, encoder, decoder):
- for (obj, dct) in zip(self.objs, self.dcts):
+ for obj, dct in zip(self.objs, self.dcts):
json_str = json.dumps(obj, cls=encoder)
reconstructed = json.loads(json_str, cls=decoder)
assert reconstructed == obj
@@ -107,9 +108,20 @@ def test_not_mine(self):
class TestNumpyCoding(CustomJSONCodingTest):
def setup_method(self):
self.codec = NUMPY_CODEC
- self.objs = [np.array([[1.0, 0.0], [2.0, 3.2]]), np.array([1, 0]),
- np.array([1.0, 2.0, 3.0], dtype=np.float32)]
- shapes = [[2, 2], [2,], [3,]]
+ self.objs = [
+ np.array([[1.0, 0.0], [2.0, 3.2]]),
+ np.array([1, 0]),
+ np.array([1.0, 2.0, 3.0], dtype=np.float32),
+ ]
+ shapes = [
+ [2, 2],
+ [
+ 2,
+ ],
+ [
+ 3,
+ ],
+ ]
dtypes = [str(arr.dtype) for arr in self.objs] # may change by system?
byte_reps = [arr.tobytes() for arr in self.objs]
self.dcts = [
@@ -126,13 +138,13 @@ def setup_method(self):
def test_object_hook(self):
# to get custom equality testing for numpy
- for (obj, dct) in zip(self.objs, self.dcts):
+ for obj, dct in zip(self.objs, self.dcts):
reconstructed = self.codec.object_hook(dct)
npt.assert_array_equal(reconstructed, obj)
def test_round_trip(self):
encoder, decoder = custom_json_factory([self.codec, BYTES_CODEC])
- for (obj, dct) in zip(self.objs, self.dcts):
+ for obj, dct in zip(self.objs, self.dcts):
json_str = json.dumps(obj, cls=encoder)
reconstructed = json.loads(json_str, cls=decoder)
npt.assert_array_equal(reconstructed, obj)
@@ -147,15 +159,19 @@ def setup_method(self):
# Note that np.float64 is treated as a float by the
# default json encode (and so returns a float not a numpy
# object).
- self.objs = [np.bool_(True), np.float16(1.0), np.float32(1.0),
- np.complex128(1.0),
- np.clongdouble(1.0), np.uint64(1)]
+ self.objs = [
+ np.bool_(True),
+ np.float16(1.0),
+ np.float32(1.0),
+ np.complex128(1.0),
+ np.clongdouble(1.0),
+ np.uint64(1),
+ ]
dtypes = [str(a.dtype) for a in self.objs]
byte_reps = [a.tobytes() for a in self.objs]
# Overly complicated extraction of the class name
# to deal with the bool_ -> bool dtype class name problem
- classes = [str(a.__class__).split("'")[1].split('.')[1]
- for a in self.objs]
+ classes = [str(a.__class__).split("'")[1].split(".")[1] for a in self.objs]
self.dcts = [
{
":is_custom:": True,
@@ -221,19 +237,23 @@ def setup_method(self):
],
"small_molecule_forcefield": "openff-2.1.1",
"nonbonded_method": "PME",
- "nonbonded_cutoff": {':is_custom:': True,
- 'magnitude': 1.0,
- 'pint_unit_registry': 'openff_units',
- 'unit': 'nanometer'},
+ "nonbonded_cutoff": {
+ ":is_custom:": True,
+ "magnitude": 1.0,
+ "pint_unit_registry": "openff_units",
+ "unit": "nanometer",
+ },
},
"thermo_settings": {
"__class__": "ThermoSettings",
"__module__": "gufe.settings.models",
":is_custom:": True,
- "temperature": {":is_custom:": True,
- "magnitude": 300.0,
- "pint_unit_registry": "openff_units",
- "unit": "kelvin"},
+ "temperature": {
+ ":is_custom:": True,
+ "magnitude": 300.0,
+ "pint_unit_registry": "openff_units",
+ "unit": "kelvin",
+ },
"pressure": None,
"ph": None,
"redox_potential": None,
@@ -275,9 +295,7 @@ def setup_method(self):
def test_openff_quantity_array_roundtrip():
- thing = unit.Quantity.from_list([
- (i + 1.0)*unit.kelvin for i in range(10)
- ])
+ thing = unit.Quantity.from_list([(i + 1.0) * unit.kelvin for i in range(10)])
dumped = json.dumps(thing, cls=tokenization.JSON_HANDLER.encoder)
@@ -305,9 +323,7 @@ def setup_method(self):
class TestUUIDCodec(CustomJSONCodingTest):
def setup_method(self):
self.codec = UUID_CODEC
- self.objs = [
- uuid4()
- ]
+ self.objs = [uuid4()]
self.dcts = [
{
":is_custom:": True,
diff --git a/gufe/tests/test_ligand_network.py b/gufe/tests/test_ligand_network.py
index 9b3be8c4..a7116a8e 100644
--- a/gufe/tests/test_ligand_network.py
+++ b/gufe/tests/test_ligand_network.py
@@ -1,17 +1,17 @@
# This code is part of gufe and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/gufe
-from typing import Iterable, NamedTuple
-import pytest
import importlib.resources
-import gufe
-from gufe.tests.test_protocol import DummyProtocol
-from gufe import SmallMoleculeComponent, LigandNetwork, LigandAtomMapping
+from collections.abc import Iterable
+from typing import NamedTuple
+import pytest
+from networkx import NetworkXError
from openff.units import unit
-
from rdkit import Chem
-from networkx import NetworkXError
+import gufe
+from gufe import LigandAtomMapping, LigandNetwork, SmallMoleculeComponent
+from gufe.tests.test_protocol import DummyProtocol
from .test_tokenization import GufeTokenizableTestsMixin
@@ -22,8 +22,10 @@ def mol_from_smiles(smi):
return m
+
class _NetworkTestContainer(NamedTuple):
"""Container to facilitate network testing"""
+
network: LigandNetwork
nodes: Iterable[SmallMoleculeComponent]
edges: Iterable[LigandAtomMapping]
@@ -33,8 +35,8 @@ class _NetworkTestContainer(NamedTuple):
@pytest.fixture
def ligandnetwork_graphml():
- with importlib.resources.path('gufe.tests.data', 'ligand_network.graphml') as file:
- with open(file, 'r') as f:
+ with importlib.resources.path("gufe.tests.data", "ligand_network.graphml") as file:
+ with open(file) as f:
yield f.read()
@@ -96,24 +98,28 @@ def singleton_node_network(mols, std_edges):
n_edges=3,
)
+
@pytest.fixture
def real_molecules_network(benzene, phenol, toluene):
"""Small network with full mappings"""
# benzene to phenol
bp_mapping = {i: i for i in range(10)}
- bp_mapping.update({10: 12, 11:11})
+ bp_mapping.update({10: 12, 11: 11})
# benzene to toluene
bt_mapping = {i: i + 4 for i in range(10)}
bt_mapping.update({10: 2, 11: 14})
- network = gufe.LigandNetwork([
- gufe.LigandAtomMapping(benzene, toluene, bt_mapping),
- gufe.LigandAtomMapping(benzene, phenol, bp_mapping),
- ])
+ network = gufe.LigandNetwork(
+ [
+ gufe.LigandAtomMapping(benzene, toluene, bt_mapping),
+ gufe.LigandAtomMapping(benzene, phenol, bp_mapping),
+ ]
+ )
return network
-@pytest.fixture(params=['simple', 'doubled_edge', 'singleton_node'])
+
+@pytest.fixture(params=["simple", "doubled_edge", "singleton_node"])
def network_container(
request,
simple_network,
@@ -122,9 +128,9 @@ def network_container(
):
"""Fixture to allow parameterization of the network test"""
network_dct = {
- 'simple': simple_network,
- 'doubled_edge': doubled_edge_network,
- 'singleton_node': singleton_node_network,
+ "simple": simple_network,
+ "doubled_edge": doubled_edge_network,
+ "singleton_node": singleton_node_network,
}
return network_dct[request.param]
@@ -140,7 +146,7 @@ def instance(self, simple_network):
def test_node_type(self, network_container):
n = network_container.network
- assert all((isinstance(node, SmallMoleculeComponent) for node in n.nodes))
+ assert all(isinstance(node, SmallMoleculeComponent) for node in n.nodes)
def test_graph(self, network_container):
# The NetworkX graph that comes from the ``.graph`` property should
@@ -150,21 +156,19 @@ def test_graph(self, network_container):
assert set(graph.nodes) == set(network_container.nodes)
assert len(graph.edges) == network_container.n_edges
# extract the AtomMappings from the nx edges
- mappings = [
- atommapping for _, _, atommapping in graph.edges.data('object')
- ]
+ mappings = [atommapping for _, _, atommapping in graph.edges.data("object")]
assert set(mappings) == set(network_container.edges)
# ensure LigandAtomMapping stored in nx edge is consistent with nx edge
- for mol1, mol2, atommapping in graph.edges.data('object'):
+ for mol1, mol2, atommapping in graph.edges.data("object"):
assert atommapping.componentA == mol1
assert atommapping.componentB == mol2
def test_graph_annotations(self, mols, std_edges):
mol1, mol2, mol3 = mols
edge12, edge23, edge13 = std_edges
- annotated = edge12.with_annotations({'foo': 'bar'})
+ annotated = edge12.with_annotations({"foo": "bar"})
network = LigandNetwork([annotated, edge23, edge13])
- assert network.graph[mol1][mol2][0]['foo'] == 'bar'
+ assert network.graph[mol1][mol2][0]["foo"] == "bar"
def test_graph_immutability(self, mols, network_container):
# The NetworkX graph that comes from that ``.graph`` property should
@@ -285,7 +289,7 @@ def test_is_connected(self, simple_network):
def test_is_not_connected(self, singleton_node_network):
assert not singleton_node_network.network.is_connected()
- @pytest.mark.parametrize('with_cofactor', [True, False])
+ @pytest.mark.parametrize("with_cofactor", [True, False])
def test_to_rbfe_alchemical_network(
self,
real_molecules_network,
@@ -297,25 +301,22 @@ def test_to_rbfe_alchemical_network(
# obviously, this particular set of ligands with this particular
# protein makes no sense, but we should still be able to set it up
if with_cofactor:
- others = {'cofactor': request.getfixturevalue('styrene')}
+ others = {"cofactor": request.getfixturevalue("styrene")}
else:
others = {}
protocol = DummyProtocol(DummyProtocol.default_settings())
rbfe = real_molecules_network.to_rbfe_alchemical_network(
- solvent=solv_comp,
- protein=prot_comp,
- protocol=protocol,
- **others
+ solvent=solv_comp, protein=prot_comp, protocol=protocol, **others
)
expected_names = {
- 'easy_rbfe_benzene_solvent_toluene_solvent',
- 'easy_rbfe_benzene_complex_toluene_complex',
- 'easy_rbfe_benzene_solvent_phenol_solvent',
- 'easy_rbfe_benzene_complex_phenol_complex',
+ "easy_rbfe_benzene_solvent_toluene_solvent",
+ "easy_rbfe_benzene_complex_toluene_complex",
+ "easy_rbfe_benzene_solvent_phenol_solvent",
+ "easy_rbfe_benzene_complex_phenol_complex",
}
- names = set(edge.name for edge in rbfe.edges)
+ names = {edge.name for edge in rbfe.edges}
assert names == expected_names
assert len(rbfe.edges) == 2 * len(real_molecules_network.edges)
@@ -324,35 +325,29 @@ def test_to_rbfe_alchemical_network(
compsA = edge.stateA.components
compsB = edge.stateB.components
- if 'solvent' in edge.name:
- labels = {'solvent', 'ligand'}
- elif 'complex' in edge.name:
- labels = {'solvent', 'ligand', 'protein'}
+ if "solvent" in edge.name:
+ labels = {"solvent", "ligand"}
+ elif "complex" in edge.name:
+ labels = {"solvent", "ligand", "protein"}
if with_cofactor:
- labels.add('cofactor')
+ labels.add("cofactor")
else: # -no-cov-
- raise RuntimeError("Something went weird in testing. Unable "
- f"to get leg for edge {edge}")
+ raise RuntimeError("Something went weird in testing. Unable " f"to get leg for edge {edge}")
assert set(compsA) == labels
assert set(compsB) == labels
- assert compsA['ligand'] != compsB['ligand']
- assert compsA['ligand'].name == 'benzene'
- assert compsA['solvent'] == compsB['solvent']
+ assert compsA["ligand"] != compsB["ligand"]
+ assert compsA["ligand"].name == "benzene"
+ assert compsA["solvent"] == compsB["solvent"]
# for things that might not always exist, use .get
- assert compsA.get('protein') == compsB.get('protein')
- assert compsA.get('cofactor') == compsB.get('cofactor')
+ assert compsA.get("protein") == compsB.get("protein")
+ assert compsA.get("cofactor") == compsB.get("cofactor")
assert isinstance(edge.mapping, gufe.ComponentMapping)
assert edge.mapping in real_molecules_network.edges
- def test_to_rbfe_alchemical_network_autoname_false(
- self,
- real_molecules_network,
- prot_comp,
- solv_comp
- ):
+ def test_to_rbfe_alchemical_network_autoname_false(self, real_molecules_network, prot_comp, solv_comp):
rbfe = real_molecules_network.to_rbfe_alchemical_network(
solvent=solv_comp,
protein=prot_comp,
@@ -364,12 +359,7 @@ def test_to_rbfe_alchemical_network_autoname_false(
for sys in [edge.stateA, edge.stateB]:
assert sys.name == ""
- def test_to_rbfe_alchemical_network_autoname_true(
- self,
- real_molecules_network,
- prot_comp,
- solv_comp
- ):
+ def test_to_rbfe_alchemical_network_autoname_true(self, real_molecules_network, prot_comp, solv_comp):
rbfe = real_molecules_network.to_rbfe_alchemical_network(
solvent=solv_comp,
protein=prot_comp,
@@ -378,33 +368,28 @@ def test_to_rbfe_alchemical_network_autoname_true(
autoname_prefix="",
)
expected_names = {
- 'benzene_complex_toluene_complex',
- 'benzene_solvent_toluene_solvent',
- 'benzene_complex_phenol_complex',
- 'benzene_solvent_phenol_solvent',
+ "benzene_complex_toluene_complex",
+ "benzene_solvent_toluene_solvent",
+ "benzene_complex_phenol_complex",
+ "benzene_solvent_phenol_solvent",
}
- names = set(edge.name for edge in rbfe.edges)
+ names = {edge.name for edge in rbfe.edges}
assert names == expected_names
@pytest.mark.xfail # method removed and on hold for now
- def test_to_rhfe_alchemical_network(self, real_molecules_network,
- solv_comp):
+ def test_to_rhfe_alchemical_network(self, real_molecules_network, solv_comp):
others = {}
protocol = DummyProtocol(DummyProtocol.default_settings())
- rhfe = real_molecules_network.to_rhfe_alchemical_network(
- solvent=solv_comp,
- protocol=protocol,
- **others
- )
+ rhfe = real_molecules_network.to_rhfe_alchemical_network(solvent=solv_comp, protocol=protocol, **others)
expected_names = {
- 'easy_rhfe_benzene_vacuum_toluene_vacuum',
- 'easy_rhfe_benzene_solvent_toluene_solvent',
- 'easy_rhfe_benzene_vacuum_phenol_vacuum',
- 'easy_rhfe_benzene_solvent_phenol_solvent',
+ "easy_rhfe_benzene_vacuum_toluene_vacuum",
+ "easy_rhfe_benzene_solvent_toluene_solvent",
+ "easy_rhfe_benzene_vacuum_phenol_vacuum",
+ "easy_rhfe_benzene_solvent_phenol_solvent",
}
- names = set(edge.name for edge in rhfe.edges)
+ names = {edge.name for edge in rhfe.edges}
assert names == expected_names
assert len(rhfe.edges) == 2 * len(real_molecules_network.edges)
@@ -414,25 +399,24 @@ def test_to_rhfe_alchemical_network(self, real_molecules_network,
compsA = edge.stateA.components
compsB = edge.stateB.components
- if 'vacuum' in edge.name:
- labels = {'ligand'}
- elif 'solvent' in edge.name:
- labels = {'ligand', 'solvent'}
+ if "vacuum" in edge.name:
+ labels = {"ligand"}
+ elif "solvent" in edge.name:
+ labels = {"ligand", "solvent"}
else: # -no-cov-
- raise RuntimeError("Something went weird in testing. Unable "
- f"to get leg for edge {edge}")
+ raise RuntimeError("Something went weird in testing. Unable " f"to get leg for edge {edge}")
labels |= set(others)
assert set(compsA) == labels
assert set(compsB) == labels
- assert compsA['ligand'] != compsB['ligand']
- assert compsA['ligand'].name == 'benzene'
- assert compsA.get('solvent') == compsB.get('solvent')
+ assert compsA["ligand"] != compsB["ligand"]
+ assert compsA["ligand"].name == "benzene"
+ assert compsA.get("solvent") == compsB.get("solvent")
- assert list(edge.mapping) == ['ligand']
- assert edge.mapping['ligand'] in real_molecules_network.edges
+ assert list(edge.mapping) == ["ligand"]
+ assert edge.mapping["ligand"] in real_molecules_network.edges
def test_empty_ligand_network(mols):
diff --git a/gufe/tests/test_ligandatommapping.py b/gufe/tests/test_ligandatommapping.py
index 93809ac7..63ae076e 100644
--- a/gufe/tests/test_ligandatommapping.py
+++ b/gufe/tests/test_ligandatommapping.py
@@ -1,12 +1,14 @@
# This code is part of gufe and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/gufe
import importlib
-import pytest
-import pathlib
import json
+import pathlib
+
import numpy as np
-from rdkit import Chem
+import pytest
from openff.units import unit
+from rdkit import Chem
+
import gufe
from gufe import LigandAtomMapping, SmallMoleculeComponent
@@ -20,7 +22,7 @@ def mol_from_smiles(smiles: str) -> gufe.SmallMoleculeComponent:
return gufe.SmallMoleculeComponent(m)
-@pytest.fixture(scope='session')
+@pytest.fixture(scope="session")
def simple_mapping():
"""Disappearing oxygen on end
@@ -28,15 +30,15 @@ def simple_mapping():
C C
"""
- molA = mol_from_smiles('CCO')
- molB = mol_from_smiles('CC')
+ molA = mol_from_smiles("CCO")
+ molB = mol_from_smiles("CC")
m = LigandAtomMapping(molA, molB, componentA_to_componentB={0: 0, 1: 1})
return m
-@pytest.fixture(scope='session')
+@pytest.fixture(scope="session")
def other_mapping():
"""Disappearing middle carbon
@@ -44,8 +46,8 @@ def other_mapping():
C C
"""
- molA = mol_from_smiles('CCO')
- molB = mol_from_smiles('CC')
+ molA = mol_from_smiles("CCO")
+ molB = mol_from_smiles("CC")
m = LigandAtomMapping(molA, molB, componentA_to_componentB={0: 0, 2: 1})
@@ -54,55 +56,82 @@ def other_mapping():
@pytest.fixture
def annotated_simple_mapping(simple_mapping):
- mapping = LigandAtomMapping(simple_mapping.componentA,
- simple_mapping.componentB,
- simple_mapping.componentA_to_componentB,
- annotations={'foo': 'bar'})
+ mapping = LigandAtomMapping(
+ simple_mapping.componentA,
+ simple_mapping.componentB,
+ simple_mapping.componentA_to_componentB,
+ annotations={"foo": "bar"},
+ )
return mapping
-@pytest.fixture(scope='session')
+@pytest.fixture(scope="session")
def benzene_maps():
MAPS = {
- 'phenol': {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6,
- 7: 7, 8: 8, 9: 9, 10: 12, 11: 11},
- 'anisole': {0: 5, 1: 6, 2: 7, 3: 8, 4: 9, 5: 10,
- 6: 11, 7: 12, 8: 13, 9: 14, 10: 2, 11: 15}}
+ "phenol": {
+ 0: 0,
+ 1: 1,
+ 2: 2,
+ 3: 3,
+ 4: 4,
+ 5: 5,
+ 6: 6,
+ 7: 7,
+ 8: 8,
+ 9: 9,
+ 10: 12,
+ 11: 11,
+ },
+ "anisole": {
+ 0: 5,
+ 1: 6,
+ 2: 7,
+ 3: 8,
+ 4: 9,
+ 5: 10,
+ 6: 11,
+ 7: 12,
+ 8: 13,
+ 9: 14,
+ 10: 2,
+ 11: 15,
+ },
+ }
return MAPS
-@pytest.fixture(scope='session')
+@pytest.fixture(scope="session")
def benzene_phenol_mapping(benzene_transforms, benzene_maps):
- molA = SmallMoleculeComponent(benzene_transforms['benzene'].to_rdkit())
- molB = SmallMoleculeComponent(benzene_transforms['phenol'].to_rdkit())
- m = LigandAtomMapping(molA, molB, benzene_maps['phenol'])
+ molA = SmallMoleculeComponent(benzene_transforms["benzene"].to_rdkit())
+ molB = SmallMoleculeComponent(benzene_transforms["phenol"].to_rdkit())
+ m = LigandAtomMapping(molA, molB, benzene_maps["phenol"])
return m
-@pytest.fixture(scope='session')
+@pytest.fixture(scope="session")
def benzene_anisole_mapping(benzene_transforms, benzene_maps):
- molA = SmallMoleculeComponent(benzene_transforms['benzene'].to_rdkit())
- molB = SmallMoleculeComponent(benzene_transforms['anisole'].to_rdkit())
- m = LigandAtomMapping(molA, molB, benzene_maps['anisole'])
+ molA = SmallMoleculeComponent(benzene_transforms["benzene"].to_rdkit())
+ molB = SmallMoleculeComponent(benzene_transforms["anisole"].to_rdkit())
+ m = LigandAtomMapping(molA, molB, benzene_maps["anisole"])
return m
-@pytest.fixture(scope='session')
+@pytest.fixture(scope="session")
def atom_mapping_basic_test_files():
# a dict of {filenames.strip(mol2): SmallMoleculeComponent} for a simple
# set of ligands
files = {}
for f in [
- '1,3,7-trimethylnaphthalene',
- '1-butyl-4-methylbenzene',
- '2,6-dimethylnaphthalene',
- '2-methyl-6-propylnaphthalene',
- '2-methylnaphthalene',
- '2-naftanol',
- 'methylcyclohexane',
- 'toluene']:
- with importlib.resources.path('gufe.tests.data.lomap_basic',
- f + '.mol2') as fn:
+ "1,3,7-trimethylnaphthalene",
+ "1-butyl-4-methylbenzene",
+ "2,6-dimethylnaphthalene",
+ "2-methyl-6-propylnaphthalene",
+ "2-methylnaphthalene",
+ "2-naftanol",
+ "methylcyclohexane",
+ "toluene",
+ ]:
+ with importlib.resources.path("gufe.tests.data.lomap_basic", f + ".mol2") as fn:
mol = Chem.MolFromMol2File(str(fn), removeHs=False)
files[f] = SmallMoleculeComponent(mol, name=f)
@@ -120,15 +149,25 @@ def test_atommapping_usage(simple_mapping):
def test_mapping_inversion(benzene_phenol_mapping):
assert benzene_phenol_mapping.componentB_to_componentA == {
- 0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9, 11: 11,
- 12: 10
+ 0: 0,
+ 1: 1,
+ 2: 2,
+ 3: 3,
+ 4: 4,
+ 5: 5,
+ 6: 6,
+ 7: 7,
+ 8: 8,
+ 9: 9,
+ 11: 11,
+ 12: 10,
}
def test_mapping_distances(benzene_phenol_mapping):
d = benzene_phenol_mapping.get_distances()
- ref = [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.34005502, 0.]
+ ref = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.34005502, 0.0]
assert isinstance(d, np.ndarray)
for i, r in zip(d, ref):
@@ -137,15 +176,27 @@ def test_mapping_distances(benzene_phenol_mapping):
def test_uniques(atom_mapping_basic_test_files):
mapping = LigandAtomMapping(
- componentA=atom_mapping_basic_test_files['methylcyclohexane'],
- componentB=atom_mapping_basic_test_files['toluene'],
- componentA_to_componentB={
- 0: 6, 1: 7, 2: 8, 3: 9, 4: 10, 5: 11, 6: 12
- }
+ componentA=atom_mapping_basic_test_files["methylcyclohexane"],
+ componentB=atom_mapping_basic_test_files["toluene"],
+ componentA_to_componentB={0: 6, 1: 7, 2: 8, 3: 9, 4: 10, 5: 11, 6: 12},
)
- assert list(mapping.componentA_unique) == [7, 8, 9, 10, 11, 12, 13, 14, 15,
- 16, 17, 18, 19, 20]
+ assert list(mapping.componentA_unique) == [
+ 7,
+ 8,
+ 9,
+ 10,
+ 11,
+ 12,
+ 13,
+ 14,
+ 15,
+ 16,
+ 17,
+ 18,
+ 19,
+ 20,
+ ]
assert list(mapping.componentB_unique) == [0, 1, 2, 3, 4, 5, 13, 14]
@@ -166,25 +217,23 @@ def test_atommapping_hash(simple_mapping, other_mapping):
def test_draw_mapping_cairo(tmpdir, simple_mapping):
with tmpdir.as_cwd():
- simple_mapping.draw_to_file('test.png')
- filed = pathlib.Path('test.png')
+ simple_mapping.draw_to_file("test.png")
+ filed = pathlib.Path("test.png")
assert filed.exists()
def test_draw_mapping_svg(tmpdir, other_mapping):
with tmpdir.as_cwd():
d2d = Chem.Draw.rdMolDraw2D.MolDraw2DSVG(600, 300, 300, 300)
- other_mapping.draw_to_file('test.svg', d2d=d2d)
- filed = pathlib.Path('test.svg')
+ other_mapping.draw_to_file("test.svg", d2d=d2d)
+ filed = pathlib.Path("test.svg")
assert filed.exists()
class TestLigandAtomMappingSerialization:
- def test_deserialize_roundtrip(self, benzene_phenol_mapping,
- benzene_anisole_mapping):
+ def test_deserialize_roundtrip(self, benzene_phenol_mapping, benzene_anisole_mapping):
- roundtrip = LigandAtomMapping.from_dict(
- benzene_phenol_mapping.to_dict())
+ roundtrip = LigandAtomMapping.from_dict(benzene_phenol_mapping.to_dict())
assert roundtrip == benzene_phenol_mapping
@@ -195,10 +244,10 @@ def test_deserialize_roundtrip(self, benzene_phenol_mapping,
def test_file_roundtrip(self, benzene_phenol_mapping, tmpdir):
with tmpdir.as_cwd():
- with open('tmpfile.json', 'w') as f:
+ with open("tmpfile.json", "w") as f:
f.write(json.dumps(benzene_phenol_mapping.to_dict()))
- with open('tmpfile.json', 'r') as f:
+ with open("tmpfile.json") as f:
d = json.load(f)
assert isinstance(d, dict)
@@ -207,27 +256,26 @@ def test_file_roundtrip(self, benzene_phenol_mapping, tmpdir):
assert roundtrip == benzene_phenol_mapping
-def test_annotated_atommapping_hash_eq(simple_mapping,
- annotated_simple_mapping):
+def test_annotated_atommapping_hash_eq(simple_mapping, annotated_simple_mapping):
assert annotated_simple_mapping != simple_mapping
assert hash(annotated_simple_mapping) != hash(simple_mapping)
def test_annotation_immutability(annotated_simple_mapping):
annot1 = annotated_simple_mapping.annotations
- annot1['foo'] = 'baz'
+ annot1["foo"] = "baz"
annot2 = annotated_simple_mapping.annotations
assert annot1 != annot2
- assert annot2 == {'foo': 'bar'}
+ assert annot2 == {"foo": "bar"}
def test_with_annotations(simple_mapping, annotated_simple_mapping):
- new_annot = simple_mapping.with_annotations({'foo': 'bar'})
+ new_annot = simple_mapping.with_annotations({"foo": "bar"})
assert new_annot == annotated_simple_mapping
def test_with_fancy_annotations(simple_mapping):
- m = simple_mapping.with_annotations({'thing': 4.0 * unit.nanometer})
+ m = simple_mapping.with_annotations({"thing": 4.0 * unit.nanometer})
assert m.key
@@ -240,36 +288,28 @@ class TestLigandAtomMappingBoundsChecks:
@pytest.fixture
def molA(self):
# 9 atoms
- return mol_from_smiles('CCO')
+ return mol_from_smiles("CCO")
@pytest.fixture
def molB(self):
# 11 atoms
- return mol_from_smiles('CCC')
+ return mol_from_smiles("CCC")
def test_too_large_A(self, molA, molB):
with pytest.raises(ValueError, match="invalid index for ComponentA"):
- LigandAtomMapping(componentA=molA,
- componentB=molB,
- componentA_to_componentB={9: 5})
+ LigandAtomMapping(componentA=molA, componentB=molB, componentA_to_componentB={9: 5})
def test_too_small_A(self, molA, molB):
with pytest.raises(ValueError, match="invalid index for ComponentA"):
- LigandAtomMapping(componentA=molA,
- componentB=molB,
- componentA_to_componentB={-2: 5})
+ LigandAtomMapping(componentA=molA, componentB=molB, componentA_to_componentB={-2: 5})
def test_too_large_B(self, molA, molB):
with pytest.raises(ValueError, match="invalid index for ComponentB"):
- LigandAtomMapping(componentA=molA,
- componentB=molB,
- componentA_to_componentB={5: 11})
+ LigandAtomMapping(componentA=molA, componentB=molB, componentA_to_componentB={5: 11})
def test_too_small_B(self, molA, molB):
with pytest.raises(ValueError, match="invalid index for ComponentB"):
- LigandAtomMapping(componentA=molA,
- componentB=molB,
- componentA_to_componentB={5: -1})
+ LigandAtomMapping(componentA=molA, componentB=molB, componentA_to_componentB={5: -1})
class TestLigandAtomMapping(GufeTokenizableTestsMixin):
diff --git a/gufe/tests/test_mapping.py b/gufe/tests/test_mapping.py
index cb8c7370..620c2f55 100644
--- a/gufe/tests/test_mapping.py
+++ b/gufe/tests/test_mapping.py
@@ -1,17 +1,21 @@
# This code is part of gufe and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/gufe
-import gufe
-from gufe import AtomMapping
import pytest
+import gufe
+from gufe import AtomMapping
from .test_tokenization import GufeTokenizableTestsMixin
class ExampleMapping(AtomMapping):
- def __init__(self, molA: gufe.SmallMoleculeComponent,
- molB: gufe.SmallMoleculeComponent, mapping):
+ def __init__(
+ self,
+ molA: gufe.SmallMoleculeComponent,
+ molB: gufe.SmallMoleculeComponent,
+ mapping,
+ ):
super().__init__(molA, molB)
self._mapping = mapping
@@ -21,9 +25,9 @@ def _defaults(cls):
def _to_dict(self):
return {
- 'molA': self._componentA,
- 'molB': self._componentB,
- 'mapping': self._mapping,
+ "molA": self._componentA,
+ "molB": self._componentB,
+ "mapping": self._mapping,
}
@classmethod
@@ -37,12 +41,10 @@ def componentB_to_componentA(self):
return {v: k for k, v in self._mapping}
def componentA_unique(self):
- return (i for i in range(self._molA.to_rdkit().GetNumAtoms())
- if i not in self._mapping)
+ return (i for i in range(self._molA.to_rdkit().GetNumAtoms()) if i not in self._mapping)
def componentB_unique(self):
- return (i for i in range(self._molB.to_rdkit().GetNumAtoms())
- if i not in self._mapping.values())
+ return (i for i in range(self._molB.to_rdkit().GetNumAtoms()) if i not in self._mapping.values())
class TestMappingAbstractClass(GufeTokenizableTestsMixin):
diff --git a/gufe/tests/test_mapping_visualization.py b/gufe/tests/test_mapping_visualization.py
index 049ed626..756ad13a 100644
--- a/gufe/tests/test_mapping_visualization.py
+++ b/gufe/tests/test_mapping_visualization.py
@@ -1,18 +1,21 @@
-import pytest
-from unittest import mock
import inspect
+from unittest import mock
+import pytest
from rdkit import Chem
import gufe
from gufe.visualization.mapping_visualization import (
- _match_elements, _get_unique_bonds_and_atoms, draw_mapping,
- draw_one_molecule_mapping, draw_unhighlighted_molecule
+ _get_unique_bonds_and_atoms,
+ _match_elements,
+ draw_mapping,
+ draw_one_molecule_mapping,
+ draw_unhighlighted_molecule,
)
# default colors currently used
-_HIGHLIGHT_COLOR = (220/255, 50/255, 32/255, 1)
-_CHANGED_ELEMENTS_COLOR = (0, 90/255, 181/255, 1)
+_HIGHLIGHT_COLOR = (220 / 255, 50 / 255, 32 / 255, 1)
+_CHANGED_ELEMENTS_COLOR = (0, 90 / 255, 181 / 255, 1)
def bound_args(func, args, kwargs):
@@ -38,11 +41,14 @@ def bound_args(func, args, kwargs):
return bound.arguments
-@pytest.mark.parametrize("at1, idx1, at2, idx2, response", [
- ["N", 0, "C", 0, False],
- ["C", 0, "C", 0, True],
- ["COC", 1, "NOC", 1, True],
- ["COON", 2, "COC", 2, False]]
+@pytest.mark.parametrize(
+ "at1, idx1, at2, idx2, response",
+ [
+ ["N", 0, "C", 0, False],
+ ["C", 0, "C", 0, True],
+ ["COC", 1, "NOC", 1, True],
+ ["COON", 2, "COC", 2, False],
+ ],
)
def test_match_elements(at1, idx1, at2, idx2, response):
mol1 = Chem.MolFromSmiles(at1)
@@ -52,56 +58,109 @@ def test_match_elements(at1, idx1, at2, idx2, response):
assert retval == response
-@pytest.fixture(scope='module')
+@pytest.fixture(scope="module")
def maps():
MAPS = {
- 'phenol': {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6,
- 7: 7, 8: 8, 9: 9, 10: 12, 11: 11},
- 'anisole': {0: 5, 1: 6, 2: 7, 3: 8, 4: 9, 5: 10,
- 6: 11, 7: 12, 8: 13, 9: 14, 10: 2, 11: 15}}
+ "phenol": {
+ 0: 0,
+ 1: 1,
+ 2: 2,
+ 3: 3,
+ 4: 4,
+ 5: 5,
+ 6: 6,
+ 7: 7,
+ 8: 8,
+ 9: 9,
+ 10: 12,
+ 11: 11,
+ },
+ "anisole": {
+ 0: 5,
+ 1: 6,
+ 2: 7,
+ 3: 8,
+ 4: 9,
+ 5: 10,
+ 6: 11,
+ 7: 12,
+ 8: 13,
+ 9: 14,
+ 10: 2,
+ 11: 15,
+ },
+ }
return MAPS
-@pytest.fixture(scope='module')
+@pytest.fixture(scope="module")
def benzene_phenol_mapping(benzene_transforms, maps):
- mol1 = benzene_transforms['benzene'].to_rdkit()
- mol2 = benzene_transforms['phenol'].to_rdkit()
- mapping = maps['phenol']
+ mol1 = benzene_transforms["benzene"].to_rdkit()
+ mol2 = benzene_transforms["phenol"].to_rdkit()
+ mapping = maps["phenol"]
return mapping, mol1, mol2
-@pytest.mark.parametrize('molname, atoms, elems, bond_changes, bond_deletions', [
- ['phenol', {10, }, {12, }, {10, }, {12, }],
- ['anisole', {0, 1, 3, 4}, {2, }, {13, }, {0, 1, 2, 3}]
-])
-def test_benzene_to_phenol_uniques(molname, atoms, elems, bond_changes, bond_deletions,
- benzene_transforms, maps):
- mol1 = benzene_transforms['benzene']
+@pytest.mark.parametrize(
+ "molname, atoms, elems, bond_changes, bond_deletions",
+ [
+ [
+ "phenol",
+ {
+ 10,
+ },
+ {
+ 12,
+ },
+ {
+ 10,
+ },
+ {
+ 12,
+ },
+ ],
+ [
+ "anisole",
+ {0, 1, 3, 4},
+ {
+ 2,
+ },
+ {
+ 13,
+ },
+ {0, 1, 2, 3},
+ ],
+ ],
+)
+def test_benzene_to_phenol_uniques(molname, atoms, elems, bond_changes, bond_deletions, benzene_transforms, maps):
+ mol1 = benzene_transforms["benzene"]
mol2 = benzene_transforms[molname]
mapping = maps[molname]
- uniques = _get_unique_bonds_and_atoms(mapping,
- mol1.to_rdkit(), mol2.to_rdkit())
+ uniques = _get_unique_bonds_and_atoms(mapping, mol1.to_rdkit(), mol2.to_rdkit())
# The benzene perturbations don't change
# no unique atoms in benzene
- assert uniques['atoms'] == set()
+ assert uniques["atoms"] == set()
# H->O
- assert uniques['elements'] == {10, }
+ assert uniques["elements"] == {
+ 10,
+ }
# One bond involved
- assert uniques['bond_changes'] == {10, }
+ assert uniques["bond_changes"] == {
+ 10,
+ }
# invert and check the molB uniques
inv_map = {v: k for k, v in mapping.items()}
- uniques = _get_unique_bonds_and_atoms(inv_map,
- mol2.to_rdkit(), mol1.to_rdkit())
+ uniques = _get_unique_bonds_and_atoms(inv_map, mol2.to_rdkit(), mol1.to_rdkit())
- assert uniques['atoms'] == atoms
- assert uniques['elements'] == elems
- assert uniques['bond_changes'] == bond_changes
- assert uniques['bond_deletions'] == bond_deletions
+ assert uniques["atoms"] == atoms
+ assert uniques["elements"] == elems
+ assert uniques["bond_changes"] == bond_changes
+ assert uniques["bond_deletions"] == bond_deletions
@mock.patch("gufe.visualization.mapping_visualization._draw_molecules", autospec=True)
@@ -112,20 +171,20 @@ def test_draw_mapping(mock_func, benzene_phenol_mapping):
draw_mapping(mapping, mol1, mol2)
mock_func.assert_called_once()
- args = bound_args(mock_func, mock_func.call_args.args,
- mock_func.call_args.kwargs)
- assert args['mols'] == [mol1, mol2]
- assert args['atoms_list'] == [{10}, {10, 12}]
- assert args['bonds_list'] == [{10}, {10, 12}]
- assert args['atom_colors'] == [{10: _CHANGED_ELEMENTS_COLOR},
- {12: _CHANGED_ELEMENTS_COLOR}]
- assert args['highlight_color'] == _HIGHLIGHT_COLOR
-
-
-@pytest.mark.parametrize('inverted', [True, False])
+ args = bound_args(mock_func, mock_func.call_args.args, mock_func.call_args.kwargs)
+ assert args["mols"] == [mol1, mol2]
+ assert args["atoms_list"] == [{10}, {10, 12}]
+ assert args["bonds_list"] == [{10}, {10, 12}]
+ assert args["atom_colors"] == [
+ {10: _CHANGED_ELEMENTS_COLOR},
+ {12: _CHANGED_ELEMENTS_COLOR},
+ ]
+ assert args["highlight_color"] == _HIGHLIGHT_COLOR
+
+
+@pytest.mark.parametrize("inverted", [True, False])
@mock.patch("gufe.visualization.mapping_visualization._draw_molecules", autospec=True)
-def test_draw_one_molecule_mapping(mock_func, benzene_phenol_mapping,
- inverted):
+def test_draw_one_molecule_mapping(mock_func, benzene_phenol_mapping, inverted):
# ensure that draw_one_molecule_mapping passes the desired parameters to
# our internal _draw_molecules method
mapping, mol1, mol2 = benzene_phenol_mapping
@@ -143,30 +202,28 @@ def test_draw_one_molecule_mapping(mock_func, benzene_phenol_mapping,
draw_one_molecule_mapping(mapping, mol1, mol2)
mock_func.assert_called_once()
- args = bound_args(mock_func, mock_func.call_args.args,
- mock_func.call_args.kwargs)
+ args = bound_args(mock_func, mock_func.call_args.args, mock_func.call_args.kwargs)
- assert args['mols'] == [mol1]
- assert args['atoms_list'] == atoms_list
- assert args['bonds_list'] == bonds_list
- assert args['atom_colors'] == atom_colors
- assert args['highlight_color'] == _HIGHLIGHT_COLOR
+ assert args["mols"] == [mol1]
+ assert args["atoms_list"] == atoms_list
+ assert args["bonds_list"] == bonds_list
+ assert args["atom_colors"] == atom_colors
+ assert args["highlight_color"] == _HIGHLIGHT_COLOR
@mock.patch("gufe.visualization.mapping_visualization._draw_molecules", autospec=True)
def test_draw_unhighlighted_molecule(mock_func, benzene_transforms):
# ensure that draw_unhighlighted_molecule passes the desired parameters
# to our internal _draw_molecules method
- mol = benzene_transforms['benzene'].to_rdkit()
+ mol = benzene_transforms["benzene"].to_rdkit()
draw_unhighlighted_molecule(mol)
mock_func.assert_called_once()
- args = bound_args(mock_func, mock_func.call_args.args,
- mock_func.call_args.kwargs)
- assert args['mols'] == [mol]
- assert args['atoms_list'] == [[]]
- assert args['bonds_list'] == [[]]
- assert args['atom_colors'] == [{}]
+ args = bound_args(mock_func, mock_func.call_args.args, mock_func.call_args.kwargs)
+ assert args["mols"] == [mol]
+ assert args["atoms_list"] == [[]]
+ assert args["bonds_list"] == [[]]
+ assert args["atom_colors"] == [{}]
# technically, we don't care what the highlight color is, so no
# assertion on that
@@ -186,4 +243,4 @@ def test_draw_one_molecule_integration_smoke(benzene_phenol_mapping):
def test_draw_unhighlighted_molecule_integration_smoke(benzene_transforms):
# integration test/smoke test to catch errors if the upstream drawing
# code changes
- draw_unhighlighted_molecule(benzene_transforms['benzene'].to_rdkit())
+ draw_unhighlighted_molecule(benzene_transforms["benzene"].to_rdkit())
diff --git a/gufe/tests/test_models.py b/gufe/tests/test_models.py
index 4fd67bf0..887cbdd7 100644
--- a/gufe/tests/test_models.py
+++ b/gufe/tests/test_models.py
@@ -6,14 +6,10 @@
import json
-from openff.units import unit
import pytest
+from openff.units import unit
-from gufe.settings.models import (
- OpenMMSystemGeneratorFFSettings,
- Settings,
- ThermoSettings,
-)
+from gufe.settings.models import OpenMMSystemGeneratorFFSettings, Settings, ThermoSettings
def test_model_schema():
@@ -45,11 +41,16 @@ def test_default_settings():
my_settings.schema_json(indent=2)
-@pytest.mark.parametrize('value,good', [
- ('parsnips', False), # shouldn't be allowed
- ('hbonds', True), ('hangles', True), ('allbonds', True), # allowed options
- ('HBonds', True), # check case insensitivity
-])
+@pytest.mark.parametrize(
+ "value,good",
+ [
+ ("parsnips", False), # shouldn't be allowed
+ ("hbonds", True),
+ ("hangles", True),
+ ("allbonds", True), # allowed options
+ ("HBonds", True), # check case insensitivity
+ ],
+)
def test_invalid_constraint(value, good):
if good:
s = OpenMMSystemGeneratorFFSettings(constraints=value)
diff --git a/gufe/tests/test_molhashing.py b/gufe/tests/test_molhashing.py
index cea50cee..33f3a7f7 100644
--- a/gufe/tests/test_molhashing.py
+++ b/gufe/tests/test_molhashing.py
@@ -1,9 +1,10 @@
-import pytest
-from gufe.molhashing import serialize_numpy, deserialize_numpy
import numpy as np
+import pytest
+
+from gufe.molhashing import deserialize_numpy, serialize_numpy
-@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int'])
+@pytest.mark.parametrize("dtype", ["float32", "float64", "int"])
def test_numpy_serialization_cycle(dtype):
arr = np.array([1, 2], dtype=dtype)
ser = serialize_numpy(arr)
diff --git a/gufe/tests/test_proteincomponent.py b/gufe/tests/test_proteincomponent.py
index 63d5b406..3efce664 100644
--- a/gufe/tests/test_proteincomponent.py
+++ b/gufe/tests/test_proteincomponent.py
@@ -1,20 +1,19 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
-import os
-import io
import copy
+import io
+import os
+
import pytest
+from numpy.testing import assert_almost_equal
+from openmm import unit
+from openmm.app import pdbfile
from rdkit import Chem
from gufe import ProteinComponent
-from .test_tokenization import GufeTokenizableTestsMixin
-
-from openmm.app import pdbfile
-from openmm import unit
-from numpy.testing import assert_almost_equal
-
from .conftest import ALL_PDB_LOADERS
+from .test_tokenization import GufeTokenizableTestsMixin
@pytest.fixture
@@ -34,13 +33,13 @@ def assert_same_pdb_lines(in_file_path, out_file_path):
if hasattr(in_file_path, "readlines"):
in_file = in_file_path
else:
- in_file = open(in_file_path, "r")
+ in_file = open(in_file_path)
if isinstance(out_file_path, io.StringIO):
out_file = out_file_path
must_close = False
else:
- out_file = open(out_file_path, mode='r')
+ out_file = open(out_file_path)
must_close = True
in_lines = in_file.readlines()
@@ -49,10 +48,8 @@ def assert_same_pdb_lines(in_file_path, out_file_path):
if must_close:
out_file.close()
- in_lines = [l for l in in_lines
- if not l.startswith(('REMARK', 'CRYST', '# Created with'))]
- out_lines = [l for l in out_lines
- if not l.startswith(('REMARK', 'CRYST', '# Created with'))]
+ in_lines = [l for l in in_lines if not l.startswith(("REMARK", "CRYST", "# Created with"))]
+ out_lines = [l for l in out_lines if not l.startswith(("REMARK", "CRYST", "# Created with"))]
assert in_lines == out_lines
@@ -94,7 +91,7 @@ def instance(self, PDB_181L_path):
return self.cls.from_pdb_file(PDB_181L_path, name="Steve")
# From
- @pytest.mark.parametrize('in_pdb_path', ALL_PDB_LOADERS.keys())
+ @pytest.mark.parametrize("in_pdb_path", ALL_PDB_LOADERS.keys())
def test_from_pdb_file(self, in_pdb_path):
in_pdb_io = ALL_PDB_LOADERS[in_pdb_path]()
p = self.cls.from_pdb_file(in_pdb_io, name="Steve")
@@ -125,8 +122,7 @@ def test_to_rdkit(self, PDB_181L_path):
assert isinstance(rdkitmol, Chem.Mol)
assert rdkitmol.GetNumAtoms() == 2639
- def _test_file_output(self, input_path, output_path, input_type,
- output_func):
+ def _test_file_output(self, input_path, output_path, input_type, output_func):
if input_type == "filename":
inp = str(output_path)
elif input_type == "Path":
@@ -134,7 +130,7 @@ def _test_file_output(self, input_path, output_path, input_type,
elif input_type == "StringIO":
inp = io.StringIO()
elif input_type == "TextIOWrapper":
- inp = open(output_path, mode='w')
+ inp = open(output_path, mode="w")
output_func(inp)
@@ -145,13 +141,10 @@ def _test_file_output(self, input_path, output_path, input_type,
if input_type == "TextIOWrapper":
inp.close()
- assert_same_pdb_lines(in_file_path=str(input_path),
- out_file_path=output_path)
+ assert_same_pdb_lines(in_file_path=str(input_path), out_file_path=output_path)
- @pytest.mark.parametrize('input_type', ['filename', 'Path', 'StringIO',
- 'TextIOWrapper'])
- def test_to_pdbx_file(self, PDBx_181L_openMMClean_path, tmp_path,
- input_type):
+ @pytest.mark.parametrize("input_type", ["filename", "Path", "StringIO", "TextIOWrapper"])
+ def test_to_pdbx_file(self, PDBx_181L_openMMClean_path, tmp_path, input_type):
p = self.cls.from_pdbx_file(str(PDBx_181L_openMMClean_path), name="Bob")
out_file_name = "tmp_181L_pdbx.cif"
out_file = tmp_path / out_file_name
@@ -160,28 +153,26 @@ def test_to_pdbx_file(self, PDBx_181L_openMMClean_path, tmp_path,
input_path=PDBx_181L_openMMClean_path,
output_path=out_file,
input_type=input_type,
- output_func=p.to_pdbx_file
+ output_func=p.to_pdbx_file,
)
- @pytest.mark.parametrize('input_type', ['filename', 'Path', 'StringIO',
- 'TextIOWrapper'])
- def test_to_pdb_input_types(self, PDB_181L_OpenMMClean_path, tmp_path,
- input_type):
+ @pytest.mark.parametrize("input_type", ["filename", "Path", "StringIO", "TextIOWrapper"])
+ def test_to_pdb_input_types(self, PDB_181L_OpenMMClean_path, tmp_path, input_type):
p = self.cls.from_pdb_file(str(PDB_181L_OpenMMClean_path), name="Bob")
self._test_file_output(
input_path=PDB_181L_OpenMMClean_path,
output_path=tmp_path / "tmp_181L.pdb",
input_type=input_type,
- output_func=p.to_pdb_file
+ output_func=p.to_pdb_file,
)
- @pytest.mark.parametrize('in_pdb_path', ALL_PDB_LOADERS.keys())
+ @pytest.mark.parametrize("in_pdb_path", ALL_PDB_LOADERS.keys())
def test_to_pdb_round_trip(self, in_pdb_path, tmp_path):
in_pdb_io = ALL_PDB_LOADERS[in_pdb_path]()
p = self.cls.from_pdb_file(in_pdb_io, name="Wuff")
- out_file_name = "tmp_"+in_pdb_path+".pdb"
+ out_file_name = "tmp_" + in_pdb_path + ".pdb"
out_file = tmp_path / out_file_name
p.to_pdb_file(str(out_file))
@@ -190,7 +181,7 @@ def test_to_pdb_round_trip(self, in_pdb_path, tmp_path):
# generate openMM reference file:
openmm_pdb = pdbfile.PDBFile(ref_in_pdb_io)
- out_ref_file_name = "tmp_"+in_pdb_path+"_openmm_ref.pdb"
+ out_ref_file_name = "tmp_" + in_pdb_path + "_openmm_ref.pdb"
out_ref_file = tmp_path / out_ref_file_name
pdbfile.PDBFile.writeFile(openmm_pdb.topology, openmm_pdb.positions, file=open(str(out_ref_file), "w"))
@@ -213,10 +204,10 @@ def test_dummy_from_dict(self, PDB_181L_OpenMMClean_path):
assert p == p2
# parametrize
- @pytest.mark.parametrize('in_pdb_path', ALL_PDB_LOADERS.keys())
+ @pytest.mark.parametrize("in_pdb_path", ALL_PDB_LOADERS.keys())
def test_to_openmm_positions(self, in_pdb_path):
in_pdb_io = ALL_PDB_LOADERS[in_pdb_path]()
- ref_in_pdb_io = ALL_PDB_LOADERS[in_pdb_path]()
+ ref_in_pdb_io = ALL_PDB_LOADERS[in_pdb_path]()
openmm_pdb = pdbfile.PDBFile(ref_in_pdb_io)
openmm_pos = openmm_pdb.positions
@@ -230,10 +221,10 @@ def test_to_openmm_positions(self, in_pdb_path):
assert_almost_equal(actual=v1, desired=v2, decimal=6)
# parametrize
- @pytest.mark.parametrize('in_pdb_path', ALL_PDB_LOADERS.keys())
+ @pytest.mark.parametrize("in_pdb_path", ALL_PDB_LOADERS.keys())
def test_to_openmm_topology(self, in_pdb_path):
- in_pdb_io = ALL_PDB_LOADERS[in_pdb_path]()
- ref_in_pdb_io = ALL_PDB_LOADERS[in_pdb_path]()
+ in_pdb_io = ALL_PDB_LOADERS[in_pdb_path]()
+ ref_in_pdb_io = ALL_PDB_LOADERS[in_pdb_path]()
openmm_pdb = pdbfile.PDBFile(ref_in_pdb_io)
openmm_top = openmm_pdb.topology
diff --git a/gufe/tests/test_protocol.py b/gufe/tests/test_protocol.py
index 4c44418f..50fb8c0f 100644
--- a/gufe/tests/test_protocol.py
+++ b/gufe/tests/test_protocol.py
@@ -2,29 +2,29 @@
# For details, see https://github.com/OpenFreeEnergy/gufe
import datetime
import itertools
-from openff.units import unit
-from typing import Optional, Iterable, List, Dict, Any, Union
-from collections import defaultdict
import pathlib
+from collections import defaultdict
+from collections.abc import Iterable, Sized
+from typing import Any, Dict, List, Optional, Union
-import pytest
import networkx as nx
import numpy as np
+import pytest
+from openff.units import unit
import gufe
+from gufe import settings
from gufe.chemicalsystem import ChemicalSystem
from gufe.mapping import ComponentMapping
-from gufe import settings
from gufe.protocols import (
Protocol,
ProtocolDAG,
- ProtocolUnit,
- ProtocolResult,
ProtocolDAGResult,
- ProtocolUnitResult,
+ ProtocolResult,
+ ProtocolUnit,
ProtocolUnitFailure,
+ ProtocolUnitResult,
)
-
from gufe.protocols.protocoldag import execute_DAG
from .test_tokenization import GufeTokenizableTestsMixin
@@ -41,15 +41,15 @@ def _execute(ctx, *, settings, stateA, stateB, mapping, start, **inputs):
class SimulationUnit(ProtocolUnit):
@staticmethod
def _execute(ctx, *, initialization, **inputs):
- output = [initialization.outputs['log']]
+ output = [initialization.outputs["log"]]
output.append("running_md_{}".format(inputs["window"]))
return dict(
log=output,
window=inputs["window"],
- key_result=(100 - (inputs["window"] - 10)**2),
+ key_result=(100 - (inputs["window"] - 10) ** 2),
scratch=ctx.scratch,
- shared=ctx.shared
+ shared=ctx.shared,
)
@@ -57,15 +57,12 @@ class FinishUnit(ProtocolUnit):
@staticmethod
def _execute(ctx, *, simulations, **inputs):
- output = [s.outputs['log'] for s in simulations]
+ output = [s.outputs["log"] for s in simulations]
output.append("assembling_results")
- key_results = {str(s.inputs['window']): s.outputs['key_result'] for s in simulations}
+ key_results = {str(s.inputs["window"]): s.outputs["key_result"] for s in simulations}
- return dict(
- log=output,
- key_results=key_results
- )
+ return dict(log=output, key_results=key_results)
class DummySpecificSettings(settings.Settings):
@@ -78,7 +75,7 @@ def get_estimate(self):
# product of neighboring simulation window `key_result`s
dgs = []
- for sample in self.data['key_results']:
+ for sample in self.data["key_results"]:
windows = sorted(sample.keys())
dg = 0
for i, j in zip(windows[:-1], windows[1:]):
@@ -88,8 +85,7 @@ def get_estimate(self):
return np.mean(dg)
- def get_uncertainty(self):
- ...
+ def get_uncertainty(self): ...
class DummyProtocol(Protocol):
@@ -112,16 +108,16 @@ def _create(
self,
stateA: ChemicalSystem,
stateB: ChemicalSystem,
- mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]]=None,
+ mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]] = None,
extends: Optional[ProtocolDAGResult] = None,
- ) -> List[ProtocolUnit]:
+ ) -> list[ProtocolUnit]:
# rip apart `extends` if needed to feed into `InitializeUnit`
if extends is not None:
# this is an example; wouldn't want to pass in whole ProtocolDAGResult into
# any ProtocolUnits below, since this could create dependency hell;
# instead, extract what's needed from it for starting point here
- starting_point = extends.protocol_unit_results[-1].outputs['key_results']
+ starting_point = extends.protocol_unit_results[-1].outputs["key_results"]
else:
starting_point = None
@@ -133,10 +129,11 @@ def _create(
stateB=stateB,
mapping=mapping,
start=starting_point,
- some_dict={'a': 2, 'b': 12})
+ some_dict={"a": 2, "b": 12},
+ )
# create several units that would each run an independent simulation
- simulations: List[ProtocolUnit] = [
+ simulations: list[ProtocolUnit] = [
SimulationUnit(settings=self.settings, name=f"sim {i}", window=i, initialization=alpha)
for i in range(self.settings.n_repeats) # type: ignore
]
@@ -147,16 +144,14 @@ def _create(
# return all `ProtocolUnit`s we created
return [alpha, *simulations, omega]
- def _gather(
- self, protocol_dag_results: Iterable[ProtocolDAGResult]
- ) -> Dict[str, Any]:
+ def _gather(self, protocol_dag_results: Iterable[ProtocolDAGResult]) -> dict[str, Any]:
outputs = defaultdict(list)
for pdr in protocol_dag_results:
for pur in pdr.terminal_protocol_unit_results:
if pur.name == "the end":
- outputs['logs'].append(pur.outputs['log'])
- outputs['key_results'].append(pur.outputs['key_results'])
+ outputs["logs"].append(pur.outputs["log"])
+ outputs["key_results"].append(pur.outputs["key_results"])
return dict(outputs)
@@ -164,7 +159,7 @@ def _gather(
class BrokenSimulationUnit(SimulationUnit):
@staticmethod
def _execute(ctx, **inputs):
- raise ValueError("I have failed my mission", {'data': 'lol'})
+ raise ValueError("I have failed my mission", {"data": "lol"})
class BrokenProtocol(DummyProtocol):
@@ -172,7 +167,7 @@ def _create(
self,
stateA: ChemicalSystem,
stateB: ChemicalSystem,
- mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]]=None,
+ mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]] = None,
extends: Optional[ProtocolDAGResult] = None,
) -> list[ProtocolUnit]:
@@ -186,12 +181,19 @@ def _create(
)
# create several units that would each run an independent simulation
- simulations: List[ProtocolUnit] = [
+ simulations: list[ProtocolUnit] = [
SimulationUnit(settings=self.settings, name=f"sim {i}", window=i, initialization=alpha) for i in range(21)
]
# introduce a broken ProtocolUnit
- simulations.append(BrokenSimulationUnit(settings=self.settings, window=21, name="problem child", initialization=alpha))
+ simulations.append(
+ BrokenSimulationUnit(
+ settings=self.settings,
+ window=21,
+ name="problem child",
+ initialization=alpha,
+ )
+ )
# gather results from simulations, finalize outputs
omega = FinishUnit(settings=self.settings, name="the end", simulations=simulations)
@@ -213,18 +215,19 @@ def instance(self):
def protocol_dag(self, solvated_ligand, vacuum_ligand, tmpdir):
protocol = DummyProtocol(settings=DummyProtocol.default_settings())
dag = protocol.create(
- stateA=solvated_ligand, stateB=vacuum_ligand, name="a dummy run",
+ stateA=solvated_ligand,
+ stateB=vacuum_ligand,
+ name="a dummy run",
mapping=None,
)
with tmpdir.as_cwd():
- shared = pathlib.Path('shared')
+ shared = pathlib.Path("shared")
shared.mkdir(parents=True)
-
- scratch = pathlib.Path('scratch')
+
+ scratch = pathlib.Path("scratch")
scratch.mkdir(parents=True)
- dagresult: ProtocolDAGResult = execute_DAG(
- dag, shared_basedir=shared, scratch_basedir=scratch)
+ dagresult: ProtocolDAGResult = execute_DAG(dag, shared_basedir=shared, scratch_basedir=scratch)
return protocol, dag, dagresult
@@ -232,18 +235,21 @@ def protocol_dag(self, solvated_ligand, vacuum_ligand, tmpdir):
def protocol_dag_broken(self, solvated_ligand, vacuum_ligand, tmpdir):
protocol = BrokenProtocol(settings=BrokenProtocol.default_settings())
dag = protocol.create(
- stateA=solvated_ligand, stateB=vacuum_ligand, name="a broken dummy run",
+ stateA=solvated_ligand,
+ stateB=vacuum_ligand,
+ name="a broken dummy run",
mapping=None,
)
with tmpdir.as_cwd():
- shared = pathlib.Path('shared')
+ shared = pathlib.Path("shared")
shared.mkdir(parents=True)
-
- scratch = pathlib.Path('scratch')
+
+ scratch = pathlib.Path("scratch")
scratch.mkdir(parents=True)
dagfailure: ProtocolDAGResult = execute_DAG(
- dag, shared_basedir=shared, scratch_basedir=scratch, raise_error=False)
+ dag, shared_basedir=shared, scratch_basedir=scratch, raise_error=False
+ )
return protocol, dag, dagfailure
@@ -257,24 +263,24 @@ def test_dag_execute(self, protocol_dag):
assert finishresult.name == "the end"
# gather SimulationUnits
- simulationresults = [dagresult.unit_to_result(pu)
- for pu in dagresult.protocol_units
- if isinstance(pu, SimulationUnit)]
+ simulationresults = [
+ dagresult.unit_to_result(pu) for pu in dagresult.protocol_units if isinstance(pu, SimulationUnit)
+ ]
# check that we have dependency information in results
- assert set(finishresult.inputs['simulations']) == {u for u in simulationresults}
+ assert set(finishresult.inputs["simulations"]) == {u for u in simulationresults}
# check that we have as many units as we expect in resulting graph
assert len(dagresult.graph) == 23
-
+
# check that each simulation has its own shared directory
- assert len(set(i.outputs['shared'] for i in simulationresults)) == len(simulationresults)
+ assert len({i.outputs["shared"] for i in simulationresults}) == len(simulationresults)
# check that each simulation has its own scratch directory
- assert len(set(i.outputs['scratch'] for i in simulationresults)) == len(simulationresults)
+ assert len({i.outputs["scratch"] for i in simulationresults}) == len(simulationresults)
# check that shared and scratch not the same for each simulation
- assert all([i.outputs['scratch'] != i.outputs['shared'] for i in simulationresults])
+ assert all([i.outputs["scratch"] != i.outputs["shared"] for i in simulationresults])
def test_terminal_units(self, protocol_dag):
prot, dag, res = protocol_dag
@@ -283,7 +289,7 @@ def test_terminal_units(self, protocol_dag):
assert len(finals) == 1
assert isinstance(finals[0], ProtocolUnitResult)
- assert finals[0].name == 'the end'
+ assert finals[0].name == "the end"
def test_dag_execute_failure(self, protocol_dag_broken):
protocol, dag, dagfailure = protocol_dag_broken
@@ -297,7 +303,7 @@ def test_dag_execute_failure(self, protocol_dag_broken):
assert failed_units[0].name == "problem child"
# parse exception arguments
- assert failed_units[0].exception[1][1]['data'] == "lol"
+ assert failed_units[0].exception[1][1]["data"] == "lol"
assert isinstance(failed_units[0], ProtocolUnitFailure)
succeeded_units = dagfailure.protocol_unit_results
@@ -307,18 +313,25 @@ def test_dag_execute_failure(self, protocol_dag_broken):
def test_dag_execute_failure_raise_error(self, solvated_ligand, vacuum_ligand, tmpdir):
protocol = BrokenProtocol(settings=BrokenProtocol.default_settings())
dag = protocol.create(
- stateA=solvated_ligand, stateB=vacuum_ligand, name="a broken dummy run",
+ stateA=solvated_ligand,
+ stateB=vacuum_ligand,
+ name="a broken dummy run",
mapping=None,
)
with tmpdir.as_cwd():
- shared = pathlib.Path('shared')
+ shared = pathlib.Path("shared")
shared.mkdir(parents=True)
-
- scratch = pathlib.Path('scratch')
+
+ scratch = pathlib.Path("scratch")
scratch.mkdir(parents=True)
with pytest.raises(ValueError, match="I have failed my mission"):
- execute_DAG(dag, shared_basedir=shared, scratch_basedir=scratch, raise_error=True)
+ execute_DAG(
+ dag,
+ shared_basedir=shared,
+ scratch_basedir=scratch,
+ raise_error=True,
+ )
def test_create_execute_gather(self, protocol_dag):
protocol, dag, dagresult = protocol_dag
@@ -328,28 +341,51 @@ def test_create_execute_gather(self, protocol_dag):
# gather aggregated results of interest
protocolresult = protocol.gather([dagresult])
- assert len(protocolresult.data['logs']) == 1
- assert len(protocolresult.data['logs'][0]) == 21 + 1
+ assert protocolresult.n_protocol_dag_results == 1
+ assert len(protocolresult.data["logs"]) == 1
+ assert len(protocolresult.data["logs"][0]) == 21 + 1
assert protocolresult.get_estimate() == 95500.0
+ def test_gather_infinite_iterable_guardrail(self, protocol_dag):
+ protocol, dag, dagresult = protocol_dag
+
+ assert dagresult.ok()
+
+ # we want an infinite generator, but one that would actually stop early in case
+ # the guardrail doesn't work, but the type system doesn't know that
+ def infinite_generator():
+ while True:
+ yield dag
+ break
+
+ gen = infinite_generator()
+ assert isinstance(gen, Iterable)
+ assert not isinstance(gen, Sized)
+
+ with pytest.raises(ValueError, match="`protocol_dag_results` must implement `__len__`"):
+ protocol.gather(infinite_generator())
+
def test_deprecation_warning_on_dict_mapping(self, instance, vacuum_ligand, solvated_ligand):
- lig = solvated_ligand.components['ligand']
+ lig = solvated_ligand.components["ligand"]
+
mapping = gufe.LigandAtomMapping(lig, lig, componentA_to_componentB={})
- with pytest.warns(DeprecationWarning,
- match="mapping input as a dict is deprecated"):
- instance.create(stateA=solvated_ligand, stateB=vacuum_ligand,
- mapping={'ligand': mapping})
+ with pytest.warns(DeprecationWarning, match="mapping input as a dict is deprecated"):
+ instance.create(
+ stateA=solvated_ligand,
+ stateB=vacuum_ligand,
+ mapping={"ligand": mapping},
+ )
class ProtocolDAGTestsMixin(GufeTokenizableTestsMixin):
-
+
def test_protocol_units(self, instance):
# ensure that protocol units are given in-order based on DAG
# dependencies
checked = []
for pu in instance.protocol_units:
- assert set(pu.dependencies).issubset(checked)
+ assert set(pu.dependencies).issubset(checked)
checked.append(pu)
def test_graph(self, instance):
@@ -394,7 +430,7 @@ def instance(self, protocol_dag):
def test_protocol_unit_results(self, instance: ProtocolDAGResult):
# ensure that protocolunitresults are given in-order based on DAG
# dependencies
- checked: List[Union[ProtocolUnitResult, ProtocolUnitFailure]] = []
+ checked: list[Union[ProtocolUnitResult, ProtocolUnitFailure]] = []
for pur in instance.protocol_unit_results:
assert set(pur.dependencies).issubset(checked)
checked.append(pur)
@@ -446,8 +482,7 @@ def test_protocol_unit_failures(self, instance: ProtocolDAGResult):
# protocolunitfailures should have no dependents
for puf in instance.protocol_unit_failures:
- assert all([puf not in pu.dependencies
- for pu in instance.protocol_unit_results])
+ assert all([puf not in pu.dependencies for pu in instance.protocol_unit_results])
for node in instance.result_graph.nodes:
with pytest.raises(KeyError):
@@ -464,10 +499,10 @@ def test_protocol_unit_failure_traceback(self, instance: ProtocolDAGResult):
class TestProtocolUnit(GufeTokenizableTestsMixin):
cls = SimulationUnit
repr = None
-
+
@pytest.fixture
def instance(self, vacuum_ligand, solvated_ligand):
-
+
# convert protocol inputs into starting points for independent simulations
alpha = InitializeUnit(
name="the beginning",
@@ -476,9 +511,9 @@ def instance(self, vacuum_ligand, solvated_ligand):
stateB=solvated_ligand,
mapping=None,
start=None,
- some_dict={'a': 2, 'b': 12},
+ some_dict={"a": 2, "b": 12},
)
-
+
return SimulationUnit(name=f"simulation", initialization=alpha)
def test_key_stable(self, instance):
@@ -489,20 +524,21 @@ def test_key_stable(self, instance):
class NoDepUnit(ProtocolUnit):
@staticmethod
- def _execute(ctx, **inputs) -> Dict[str, Any]:
- return {'local': inputs['val'] ** 2}
+ def _execute(ctx, **inputs) -> dict[str, Any]:
+ return {"local": inputs["val"] ** 2}
class NoDepResults(ProtocolResult):
def get_estimate(self):
- return sum(self.data['vals'])
+ return sum(self.data["vals"])
def get_uncertainty(self):
- return len(self.data['vals'])
+ return len(self.data["vals"])
class NoDepsProtocol(Protocol):
"""A protocol without dependencies"""
+
result_cls = NoDepResults
@classmethod
@@ -514,20 +550,21 @@ def _default_settings(cls):
return settings.Settings.get_defaults()
def _create(
- self,
- stateA: ChemicalSystem,
- stateB: ChemicalSystem,
- mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]] = None,
- extends: Optional[ProtocolDAGResult] = None,
- ) -> List[ProtocolUnit]:
- return [NoDepUnit(settings=self.settings,
- val=i)
- for i in range(3)]
+ self,
+ stateA: ChemicalSystem,
+ stateB: ChemicalSystem,
+ mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]] = None,
+ extends: Optional[ProtocolDAGResult] = None,
+ ) -> list[ProtocolUnit]:
+ return [NoDepUnit(settings=self.settings, val=i) for i in range(3)]
def _gather(self, dag_results):
return {
- 'vals': list(itertools.chain.from_iterable(
- (d.outputs['local'] for d in dag.protocol_unit_results) for dag in dag_results)),
+ "vals": list(
+ itertools.chain.from_iterable(
+ (d.outputs["local"] for d in dag.protocol_unit_results) for dag in dag_results
+ )
+ ),
}
@@ -539,19 +576,20 @@ def protocol(self):
@pytest.fixture()
def dag(self, protocol):
return protocol.create(
- stateA=ChemicalSystem(components={'solvent': gufe.SolventComponent(positive_ion='Na')}),
- stateB=ChemicalSystem(components={'solvent': gufe.SolventComponent(positive_ion='Li')}),
- mapping=None)
+ stateA=ChemicalSystem(components={"solvent": gufe.SolventComponent(positive_ion="Na")}),
+ stateB=ChemicalSystem(components={"solvent": gufe.SolventComponent(positive_ion="Li")}),
+ mapping=None,
+ )
def test_create(self, dag):
assert len(dag.protocol_units) == 3
def test_gather(self, protocol, dag, tmpdir):
with tmpdir.as_cwd():
- shared = pathlib.Path('shared')
+ shared = pathlib.Path("shared")
shared.mkdir(parents=True)
-
- scratch = pathlib.Path('scratch')
+
+ scratch = pathlib.Path("scratch")
scratch.mkdir(parents=True)
dag_result = execute_DAG(dag, shared_basedir=shared, scratch_basedir=scratch)
@@ -565,10 +603,10 @@ def test_gather(self, protocol, dag, tmpdir):
def test_terminal_units(self, protocol, dag, tmpdir):
with tmpdir.as_cwd():
- shared = pathlib.Path('shared')
+ shared = pathlib.Path("shared")
shared.mkdir(parents=True)
-
- scratch = pathlib.Path('scratch')
+
+ scratch = pathlib.Path("scratch")
scratch.mkdir(parents=True)
# we have no dependencies, so this should be all three Unit results
@@ -581,6 +619,7 @@ def test_terminal_units(self, protocol, dag, tmpdir):
class TestProtocolDAGResult:
"""tests for combinations of failures and successes in a DAGResult"""
+
@staticmethod
@pytest.fixture()
def units() -> list[ProtocolUnit]:
@@ -593,10 +632,16 @@ def successes(units) -> list[ProtocolUnitResult]:
t1 = datetime.datetime.now()
t2 = datetime.datetime.now()
- return [ProtocolUnitResult(source_key=u.key, inputs=u.inputs,
- outputs={'result': i ** 2},
- start_time=t1, end_time=t2)
- for i, u in enumerate(units)]
+ return [
+ ProtocolUnitResult(
+ source_key=u.key,
+ inputs=u.inputs,
+ outputs={"result": i**2},
+ start_time=t1,
+ end_time=t2,
+ )
+ for i, u in enumerate(units)
+ ]
@staticmethod
@pytest.fixture()
@@ -605,13 +650,21 @@ def failures(units) -> list[list[ProtocolUnitFailure]]:
t1 = datetime.datetime.now()
t2 = datetime.datetime.now()
- return [[ProtocolUnitFailure(source_key=u.key, inputs=u.inputs,
- outputs=dict(),
- exception=('ValueError', "Didn't feel like it"),
- traceback='foo',
- start_time=t1, end_time=t2)
- for i in range(2)]
- for u in units]
+ return [
+ [
+ ProtocolUnitFailure(
+ source_key=u.key,
+ inputs=u.inputs,
+ outputs=dict(),
+ exception=("ValueError", "Didn't feel like it"),
+ traceback="foo",
+ start_time=t1,
+ end_time=t2,
+ )
+ for i in range(2)
+ ]
+ for u in units
+ ]
def test_all_successes(self, units, successes):
dagresult = ProtocolDAGResult(
@@ -679,22 +732,26 @@ def test_foreign_objects(self, units, successes):
def test_execute_DAG_retries(solvated_ligand, vacuum_ligand, tmpdir):
protocol = BrokenProtocol(settings=BrokenProtocol.default_settings())
dag = protocol.create(
- stateA=solvated_ligand, stateB=vacuum_ligand, mapping=None,
+ stateA=solvated_ligand,
+ stateB=vacuum_ligand,
+ mapping=None,
)
with tmpdir.as_cwd():
- shared = pathlib.Path('shared')
+ shared = pathlib.Path("shared")
shared.mkdir(parents=True)
- scratch = pathlib.Path('scratch')
+ scratch = pathlib.Path("scratch")
scratch.mkdir(parents=True)
- r = execute_DAG(dag,
- shared_basedir=shared,
- scratch_basedir=scratch,
- keep_shared=True,
- keep_scratch=True,
- raise_error=False,
- n_retries=3)
+ r = execute_DAG(
+ dag,
+ shared_basedir=shared,
+ scratch_basedir=scratch,
+ keep_shared=True,
+ keep_scratch=True,
+ raise_error=False,
+ n_retries=3,
+ )
assert not r.ok()
@@ -708,26 +765,31 @@ def test_execute_DAG_retries(solvated_ligand, vacuum_ligand, tmpdir):
# final failure
assert number_unit_results == number_dirs == 26
+
def test_execute_DAG_bad_nretries(solvated_ligand, vacuum_ligand, tmpdir):
protocol = BrokenProtocol(settings=BrokenProtocol.default_settings())
dag = protocol.create(
- stateA=solvated_ligand, stateB=vacuum_ligand, mapping=None,
+ stateA=solvated_ligand,
+ stateB=vacuum_ligand,
+ mapping=None,
)
with tmpdir.as_cwd():
- shared = pathlib.Path('shared')
+ shared = pathlib.Path("shared")
shared.mkdir(parents=True)
- scratch = pathlib.Path('scratch')
+ scratch = pathlib.Path("scratch")
scratch.mkdir(parents=True)
with pytest.raises(ValueError):
- r = execute_DAG(dag,
- shared_basedir=shared,
- scratch_basedir=scratch,
- keep_shared=True,
- keep_scratch=True,
- raise_error=False,
- n_retries=-1)
+ r = execute_DAG(
+ dag,
+ shared_basedir=shared,
+ scratch_basedir=scratch,
+ keep_shared=True,
+ keep_scratch=True,
+ raise_error=False,
+ n_retries=-1,
+ )
def test_settings_readonly():
diff --git a/gufe/tests/test_protocoldag.py b/gufe/tests/test_protocoldag.py
index 9454eda0..5c898492 100644
--- a/gufe/tests/test_protocoldag.py
+++ b/gufe/tests/test_protocoldag.py
@@ -2,6 +2,7 @@
# For details, see https://github.com/OpenFreeEnergy/gufe
import os
import pathlib
+
import pytest
from openff.units import unit
@@ -12,15 +13,15 @@
class WriterUnit(gufe.ProtocolUnit):
@staticmethod
def _execute(ctx, **inputs):
- my_id = inputs['identity']
+ my_id = inputs["identity"]
- with open(os.path.join(ctx.shared, f'unit_{my_id}_shared.txt'), 'w') as out:
- out.write(f'unit {my_id} existed!\n')
- with open(os.path.join(ctx.scratch, f'unit_{my_id}_scratch.txt'), 'w') as out:
- out.write(f'unit {my_id} was here\n')
+ with open(os.path.join(ctx.shared, f"unit_{my_id}_shared.txt"), "w") as out:
+ out.write(f"unit {my_id} existed!\n")
+ with open(os.path.join(ctx.scratch, f"unit_{my_id}_scratch.txt"), "w") as out:
+ out.write(f"unit {my_id} was here\n")
return {
- 'log': 'finished',
+ "log": "finished",
}
@@ -30,11 +31,9 @@ class WriterSettings(gufe.settings.Settings):
class WriterProtocolResult(gufe.ProtocolResult):
- def get_estimate(self):
- ...
+ def get_estimate(self): ...
- def get_uncertainty(self):
- ...
+ def get_uncertainty(self): ...
class WriterProtocol(gufe.Protocol):
@@ -45,18 +44,15 @@ def _default_settings(cls):
return WriterSettings(
thermo_settings=gufe.settings.ThermoSettings(temperature=298 * unit.kelvin),
forcefield_settings=gufe.settings.OpenMMSystemGeneratorFFSettings(),
- n_repeats=4
+ n_repeats=4,
)
-
@classmethod
def _defaults(cls):
return {}
- def _create(self, stateA, stateB, mapping, extends=None) -> list[gufe.ProtocolUnit]:
- return [
- WriterUnit(identity=i) for i in range(self.settings.n_repeats) # type: ignore
- ]
+ def _create(self, stateA, stateB, mapping, extends=None) -> list[gufe.ProtocolUnit]:
+ return [WriterUnit(identity=i) for i in range(self.settings.n_repeats)] # type: ignore
def _gather(self, results):
return {}
@@ -72,34 +68,36 @@ def writefile_dag():
return p.create(stateA=s1, stateB=s2, mapping=[])
-@pytest.mark.parametrize('keep_shared', [False, True])
-@pytest.mark.parametrize('keep_scratch', [False, True])
+@pytest.mark.parametrize("keep_shared", [False, True])
+@pytest.mark.parametrize("keep_scratch", [False, True])
def test_execute_dag(tmpdir, keep_shared, keep_scratch, writefile_dag):
with tmpdir.as_cwd():
- shared = pathlib.Path('shared')
+ shared = pathlib.Path("shared")
shared.mkdir(parents=True)
- scratch = pathlib.Path('scratch')
+ scratch = pathlib.Path("scratch")
scratch.mkdir(parents=True)
-
+
# run dag
- execute_DAG(writefile_dag,
- shared_basedir=shared,
- scratch_basedir=scratch,
- keep_shared=keep_shared,
- keep_scratch=keep_scratch)
-
+ execute_DAG(
+ writefile_dag,
+ shared_basedir=shared,
+ scratch_basedir=scratch,
+ keep_shared=keep_shared,
+ keep_scratch=keep_scratch,
+ )
+
# check outputs are as expected
# will have produced 4 files in scratch and shared directory
for pu in writefile_dag.protocol_units:
- identity = pu.inputs['identity']
- shared_file = os.path.join(shared,
- f'shared_{str(pu.key)}_attempt_0',
- f'unit_{identity}_shared.txt')
- scratch_file = os.path.join(scratch,
- f'scratch_{str(pu.key)}_attempt_0',
- f'unit_{identity}_scratch.txt')
+ identity = pu.inputs["identity"]
+ shared_file = os.path.join(shared, f"shared_{str(pu.key)}_attempt_0", f"unit_{identity}_shared.txt")
+ scratch_file = os.path.join(
+ scratch,
+ f"scratch_{str(pu.key)}_attempt_0",
+ f"unit_{identity}_scratch.txt",
+ )
if keep_shared:
assert os.path.exists(shared_file)
else:
diff --git a/gufe/tests/test_protocolresult.py b/gufe/tests/test_protocolresult.py
index 4746906b..86ee0c1c 100644
--- a/gufe/tests/test_protocolresult.py
+++ b/gufe/tests/test_protocolresult.py
@@ -1,18 +1,19 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/gufe
-from openff.units import unit
import pytest
+from openff.units import unit
import gufe
+
from .test_tokenization import GufeTokenizableTestsMixin
class DummyProtocolResult(gufe.ProtocolResult):
def get_estimate(self):
- return self.data['estimate']
+ return self.data["estimate"]
def get_uncertainty(self):
- return self.data['uncertainty']
+ return self.data["uncertainty"]
class TestProtocolResult(GufeTokenizableTestsMixin):
@@ -25,3 +26,31 @@ def instance(self):
estimate=4.2 * unit.kilojoule_per_mole,
uncertainty=0.2 * unit.kilojoule_per_mole,
)
+
+ def test_protocolresult_get_estimate(self, instance):
+ assert instance.get_estimate() == 4.2 * unit.kilojoule_per_mole
+
+ def test_protocolresult_get_uncertainty(self, instance):
+ assert instance.get_uncertainty() == 0.2 * unit.kilojoule_per_mole
+
+ def test_protocolresult_default_n_protocol_dag_results(self, instance):
+ assert instance.n_protocol_dag_results == 0
+
+ def test_protocol_result_from_dict_missing_n_protocol_dag_results(self, instance):
+ protocol_result_dict_form = instance.to_dict()
+ assert DummyProtocolResult.from_dict(protocol_result_dict_form) == instance
+ del protocol_result_dict_form["n_protocol_dag_results"]
+ assert DummyProtocolResult.from_dict(protocol_result_dict_form) == instance
+
+ @pytest.mark.parametrize("arg, expected", [(0, 0), (1, 1), (-1, ValueError)])
+ def test_protocolresult_get_n_protocol_dag_results_args(self, arg, expected):
+ try:
+ protocol_result = DummyProtocolResult(
+ n_protocol_dag_results=arg,
+ estimate=4.2 * unit.kilojoule_per_mole,
+ uncertainty=0.2 * unit.kilojoule_per_mole,
+ )
+ assert protocol_result.n_protocol_dag_results == expected
+ except ValueError:
+ if expected is not ValueError:
+ raise AssertionError()
diff --git a/gufe/tests/test_protocolunit.py b/gufe/tests/test_protocolunit.py
index 9896f856..51a37a47 100644
--- a/gufe/tests/test_protocolunit.py
+++ b/gufe/tests/test_protocolunit.py
@@ -1,8 +1,9 @@
import string
-import pytest
from pathlib import Path
-from gufe.protocols.protocolunit import ProtocolUnit, Context, ProtocolUnitFailure, ProtocolUnitResult
+import pytest
+
+from gufe.protocols.protocolunit import Context, ProtocolUnit, ProtocolUnitFailure, ProtocolUnitResult
from gufe.tests.test_tokenization import GufeTokenizableTestsMixin
@@ -51,27 +52,27 @@ def test_execute(self, tmpdir):
unit = DummyUnit()
- shared = Path('shared') / str(unit.key)
+ shared = Path("shared") / str(unit.key)
shared.mkdir(parents=True)
- scratch = Path('scratch') / str(unit.key)
+ scratch = Path("scratch") / str(unit.key)
scratch.mkdir(parents=True)
ctx = Context(shared=shared, scratch=scratch)
-
+
u: ProtocolUnitFailure = unit.execute(context=ctx, an_input=3)
assert u.exception[0] == "ValueError"
unit = DummyUnit()
- shared = Path('shared') / str(unit.key)
+ shared = Path("shared") / str(unit.key)
shared.mkdir(parents=True)
- scratch = Path('scratch') / str(unit.key)
+ scratch = Path("scratch") / str(unit.key)
scratch.mkdir(parents=True)
ctx = Context(shared=shared, scratch=scratch)
-
+
# now try actually letting the error raise on execute
with pytest.raises(ValueError, match="should always be 2"):
unit.execute(context=ctx, raise_error=True, an_input=3)
@@ -81,23 +82,23 @@ def test_execute_KeyboardInterrupt(self, tmpdir):
unit = DummyKeyboardInterruptUnit()
- shared = Path('shared') / str(unit.key)
+ shared = Path("shared") / str(unit.key)
shared.mkdir(parents=True)
- scratch = Path('scratch') / str(unit.key)
+ scratch = Path("scratch") / str(unit.key)
scratch.mkdir(parents=True)
ctx = Context(shared=shared, scratch=scratch)
-
+
with pytest.raises(KeyboardInterrupt):
unit.execute(context=ctx, an_input=3)
-
+
u: ProtocolUnitResult = unit.execute(context=ctx, an_input=2)
- assert u.outputs == {'foo': 'bar'}
+ assert u.outputs == {"foo": "bar"}
def test_normalize(self, dummy_unit):
thingy = dummy_unit.key
- assert thingy.startswith('DummyUnit-')
- assert all(t in string.hexdigits for t in thingy.partition('-')[-1])
+ assert thingy.startswith("DummyUnit-")
+ assert all(t in string.hexdigits for t in thingy.partition("-")[-1])
diff --git a/gufe/tests/test_serialization_migration.py b/gufe/tests/test_serialization_migration.py
index 3e664252..69d71d4f 100644
--- a/gufe/tests/test_serialization_migration.py
+++ b/gufe/tests/test_serialization_migration.py
@@ -1,47 +1,51 @@
-import pytest
import copy
+from typing import Any, Optional, Type
+import pytest
+from pydantic import BaseModel
+
+from gufe.tests.test_tokenization import GufeTokenizableTestsMixin
from gufe.tokenization import (
GufeTokenizable,
- new_key_added,
- old_key_removed,
- key_renamed,
- nested_key_moved,
- from_dict,
_label_to_parts,
_pop_nested,
_set_nested,
+ from_dict,
+ key_renamed,
+ nested_key_moved,
+ new_key_added,
+ old_key_removed,
)
-from gufe.tests.test_tokenization import GufeTokenizableTestsMixin
-from pydantic import BaseModel
-
-from typing import Optional, Any, Type
-
@pytest.fixture
def nested_data():
- return {
- "foo": {"foo2" : [{"foo3": "foo4"}, "foo5"]},
- "bar": ["bar2", "bar3"]
- }
-
-@pytest.mark.parametrize('label, expected', [
- ("foo", ["foo"]),
- ("foo.foo2", ["foo", "foo2"]),
- ("foo.foo2[0]", ["foo", "foo2", 0]),
- ("foo.foo2[0].foo3", ["foo", "foo2", 0, "foo3"]),
-])
+ return {"foo": {"foo2": [{"foo3": "foo4"}, "foo5"]}, "bar": ["bar2", "bar3"]}
+
+
+@pytest.mark.parametrize(
+ "label, expected",
+ [
+ ("foo", ["foo"]),
+ ("foo.foo2", ["foo", "foo2"]),
+ ("foo.foo2[0]", ["foo", "foo2", 0]),
+ ("foo.foo2[0].foo3", ["foo", "foo2", 0, "foo3"]),
+ ],
+)
def test_label_to_parts(label, expected):
assert _label_to_parts(label) == expected
-@pytest.mark.parametrize('label, popped, remaining', [
- ("foo", {"foo2" : [{"foo3": "foo4"}, "foo5"]}, {}),
- ("foo.foo2", [{"foo3": "foo4"}, "foo5"], {"foo": {}}),
- ("foo.foo2[0]", {"foo3": "foo4"}, {"foo": {"foo2": ["foo5"]}}),
- ("foo.foo2[0].foo3", "foo4", {"foo": {"foo2": [{}, "foo5"]}}),
- ("foo.foo2[1]", "foo5", {"foo": {"foo2": [{"foo3": "foo4"}]}}),
-])
+
+@pytest.mark.parametrize(
+ "label, popped, remaining",
+ [
+ ("foo", {"foo2": [{"foo3": "foo4"}, "foo5"]}, {}),
+ ("foo.foo2", [{"foo3": "foo4"}, "foo5"], {"foo": {}}),
+ ("foo.foo2[0]", {"foo3": "foo4"}, {"foo": {"foo2": ["foo5"]}}),
+ ("foo.foo2[0].foo3", "foo4", {"foo": {"foo2": [{}, "foo5"]}}),
+ ("foo.foo2[1]", "foo5", {"foo": {"foo2": [{"foo3": "foo4"}]}}),
+ ],
+)
def test_pop_nested(nested_data, label, popped, remaining):
val = _pop_nested(nested_data, label)
expected_remaining = {"bar": ["bar2", "bar3"]}
@@ -49,13 +53,17 @@ def test_pop_nested(nested_data, label, popped, remaining):
assert val == popped
assert nested_data == expected_remaining
-@pytest.mark.parametrize("label, expected_foo", [
- ("foo", {"foo": 10}),
- ("foo.foo2", {"foo": {"foo2": 10}}),
- ("foo.foo2[0]", {"foo": {"foo2": [10, "foo5"]}}),
- ("foo.foo2[0].foo3", {"foo": {"foo2": [{"foo3": 10}, "foo5"]}}),
- ("foo.foo2[1]", {"foo": {"foo2": [{"foo3": "foo4"}, 10]}}),
-])
+
+@pytest.mark.parametrize(
+ "label, expected_foo",
+ [
+ ("foo", {"foo": 10}),
+ ("foo.foo2", {"foo": {"foo2": 10}}),
+ ("foo.foo2[0]", {"foo": {"foo2": [10, "foo5"]}}),
+ ("foo.foo2[0].foo3", {"foo": {"foo2": [{"foo3": 10}, "foo5"]}}),
+ ("foo.foo2[1]", {"foo": {"foo2": [{"foo3": "foo4"}, 10]}}),
+ ],
+)
def test_set_nested(nested_data, label, expected_foo):
_set_nested(nested_data, label, 10)
expected = {"bar": ["bar2", "bar3"]}
@@ -65,6 +73,7 @@ def test_set_nested(nested_data, label, expected_foo):
class _DefaultBase(GufeTokenizable):
"""Convenience class to avoid rewriting these methods"""
+
@classmethod
def _from_dict(cls, dct):
return cls(**dct)
@@ -80,16 +89,17 @@ def _schema_version(cls):
# this represents an "original" object with fields `foo` and `bar`
_SERIALIZED_OLD = {
- '__module__': None, # define in each test
- '__qualname__': None, # define in each test
- 'foo': "foo",
- 'bar': "bar",
- ':version:': 1,
+ "__module__": None, # define in each test
+ "__qualname__": None, # define in each test
+ "foo": "foo",
+ "bar": "bar",
+ ":version:": 1,
}
class KeyAdded(_DefaultBase):
"""Add key ``qux`` to the object's dict"""
+
def __init__(self, foo, bar, qux=10):
self.foo = foo
self.bar = bar
@@ -98,7 +108,7 @@ def __init__(self, foo, bar, qux=10):
@classmethod
def serialization_migration(cls, dct, version):
if version == 1:
- dct = new_key_added(dct, 'qux', 10)
+ dct = new_key_added(dct, "qux", 10)
return dct
@@ -108,6 +118,7 @@ def _to_dict(self):
class KeyRemoved(_DefaultBase):
"""Remove key ``bar`` from the object's dict"""
+
def __init__(self, foo):
self.foo = foo
@@ -124,6 +135,7 @@ def _to_dict(self):
class KeyRenamed(_DefaultBase):
"""Rename key ``bar`` to ``baz`` in the object's dict"""
+
def __init__(self, foo, baz):
self.foo = foo
self.baz = baz
@@ -153,17 +165,17 @@ def instance(self):
def _prep_dct(self, dct):
dct = copy.deepcopy(self.input_dict)
- dct['__module__'] = self.cls.__module__
- dct['__qualname__'] = self.cls.__qualname__
+ dct["__module__"] = self.cls.__module__
+ dct["__qualname__"] = self.cls.__qualname__
return dct
def test_serialization_migration(self):
# in these examples, self.kwargs is the same as the output of
# serialization_migration (not necessarily true for all classes)
dct = self._prep_dct(self.input_dict)
- del dct['__module__']
- del dct['__qualname__']
- version = dct.pop(':version:')
+ del dct["__module__"]
+ del dct["__qualname__"]
+ version = dct.pop(":version:")
assert self.cls.serialization_migration(dct, version) == self.kwargs
def test_migration(self, instance):
@@ -172,6 +184,7 @@ def test_migration(self, instance):
expected = instance
assert expected == reconstructed
+
class TestKeyAdded(MigrationTester):
cls = KeyAdded
input_dict = _SERIALIZED_OLD
@@ -196,12 +209,7 @@ class TestKeyRenamed(MigrationTester):
"__module__": ...,
"__qualname__": ...,
":version:": 1,
- "settings": {
- "son": {
- "son_child": 10
- },
- "daughter": {}
- }
+ "settings": {"son": {"son_child": 10}, "daughter": {}},
}
@@ -211,6 +219,7 @@ class SonSettings(BaseModel):
class DaughterSettings(BaseModel):
"""v2 model has child; v1 would not"""
+
daughter_child: int
@@ -224,11 +233,11 @@ def __init__(self, settings: GrandparentSettings):
self.settings = settings
def _to_dict(self):
- return {'settings': self.settings.dict()}
+ return {"settings": self.settings.dict()}
@classmethod
def _from_dict(cls, dct):
- settings = GrandparentSettings.parse_obj(dct['settings'])
+ settings = GrandparentSettings.parse_obj(dct["settings"])
return cls(settings=settings)
@classmethod
@@ -241,7 +250,7 @@ def serialization_migration(cls, dct, version):
dct = nested_key_moved(
dct,
old_name="settings.son.son_child",
- new_name="settings.daughter.daughter_child"
+ new_name="settings.daughter.daughter_child",
)
return dct
@@ -250,13 +259,8 @@ def serialization_migration(cls, dct, version):
class TestNestedKeyMoved(MigrationTester):
cls = Grandparent
input_dict = _SERIALIZED_NESTED_OLD
- kwargs = {
- 'settings': {'son': {}, 'daughter': {'daughter_child': 10}}
- }
+ kwargs = {"settings": {"son": {}, "daughter": {"daughter_child": 10}}}
@pytest.fixture
def instance(self):
- return self.cls(GrandparentSettings(
- son=SonSettings(),
- daughter=DaughterSettings(daughter_child=10)
- ))
+ return self.cls(GrandparentSettings(son=SonSettings(), daughter=DaughterSettings(daughter_child=10)))
diff --git a/gufe/tests/test_smallmoleculecomponent.py b/gufe/tests/test_smallmoleculecomponent.py
index 354771c9..e224227d 100644
--- a/gufe/tests/test_smallmoleculecomponent.py
+++ b/gufe/tests/test_smallmoleculecomponent.py
@@ -3,6 +3,7 @@
import importlib
import importlib.resources
+
try:
import openff.toolkit.topology
from openff.units import unit
@@ -10,54 +11,58 @@
HAS_OFFTK = False
else:
HAS_OFFTK = True
+import json
import os
from unittest import mock
-import pytest
-from gufe import SmallMoleculeComponent
-from gufe.components.explicitmoleculecomponent import (
- _ensure_ofe_name,
-)
-import gufe
-import json
+import pytest
from rdkit import Chem
from rdkit.Chem import AllChem
+
+import gufe
+from gufe import SmallMoleculeComponent
+from gufe.components.explicitmoleculecomponent import _ensure_ofe_name
from gufe.tokenization import TOKENIZABLE_REGISTRY
from .test_tokenization import GufeTokenizableTestsMixin
+
@pytest.fixture
def alt_ethane():
mol = Chem.AddHs(Chem.MolFromSmiles("CC"))
Chem.AllChem.Compute2DCoords(mol)
return SmallMoleculeComponent(mol)
+
@pytest.fixture
def named_ethane():
mol = Chem.AddHs(Chem.MolFromSmiles("CC"))
Chem.AllChem.Compute2DCoords(mol)
- return SmallMoleculeComponent(mol, name='ethane')
-
-
-@pytest.mark.parametrize('internal,rdkit_name,name,expected', [
- ('', 'foo', '', 'foo'),
- ('', '', 'foo', 'foo'),
- ('', 'bar', 'foo', 'foo'),
- ('bar', '', 'foo', 'foo'),
- ('baz', 'bar', 'foo', 'foo'),
- ('foo', '', '', 'foo'),
-])
+ return SmallMoleculeComponent(mol, name="ethane")
+
+
+@pytest.mark.parametrize(
+ "internal,rdkit_name,name,expected",
+ [
+ ("", "foo", "", "foo"),
+ ("", "", "foo", "foo"),
+ ("", "bar", "foo", "foo"),
+ ("bar", "", "foo", "foo"),
+ ("baz", "bar", "foo", "foo"),
+ ("foo", "", "", "foo"),
+ ],
+)
def test_ensure_ofe_name(internal, rdkit_name, name, expected, recwarn):
rdkit = Chem.AddHs(Chem.MolFromSmiles("CC"))
if internal:
- rdkit.SetProp('_Name', internal)
+ rdkit.SetProp("_Name", internal)
if rdkit_name:
- rdkit.SetProp('ofe-name', rdkit_name)
+ rdkit.SetProp("ofe-name", rdkit_name)
out_name = _ensure_ofe_name(rdkit, name)
- if {rdkit_name, internal} - {'foo', ''}:
+ if {rdkit_name, internal} - {"foo", ""}:
# we should warn if rdkit properties are anything other than 'foo'
# (expected) or the empty string (not set)
assert len(recwarn) == 1
@@ -91,22 +96,22 @@ def test_warn_multiple_conformers(self):
def test_rdkit_independence(self):
# once we've constructed a Molecule, it is independent from the source
- mol = Chem.MolFromSmiles('CC')
+ mol = Chem.MolFromSmiles("CC")
AllChem.Compute2DCoords(mol)
our_mol = SmallMoleculeComponent.from_rdkit(mol)
- mol.SetProp('foo', 'bar') # this is the source molecule, not ours
+ mol.SetProp("foo", "bar") # this is the source molecule, not ours
with pytest.raises(KeyError):
- our_mol.to_rdkit().GetProp('foo')
+ our_mol.to_rdkit().GetProp("foo")
def test_rdkit_copy_source_copy(self):
# we should copy in any properties that were in the source molecule
- mol = Chem.MolFromSmiles('CC')
+ mol = Chem.MolFromSmiles("CC")
AllChem.Compute2DCoords(mol)
- mol.SetProp('foo', 'bar')
+ mol.SetProp("foo", "bar")
our_mol = SmallMoleculeComponent.from_rdkit(mol)
- assert our_mol.to_rdkit().GetProp('foo') == 'bar'
+ assert our_mol.to_rdkit().GetProp("foo") == "bar"
def test_equality_and_hash(self, ethane, alt_ethane):
assert hash(ethane) == hash(alt_ethane)
@@ -118,13 +123,13 @@ def test_equality_and_hash_name_differs(self, ethane, named_ethane):
assert ethane != named_ethane
def test_smiles(self, named_ethane):
- assert named_ethane.smiles == 'CC'
+ assert named_ethane.smiles == "CC"
def test_name(self, named_ethane):
- assert named_ethane.name == 'ethane'
+ assert named_ethane.name == "ethane"
def test_empty_name(self, alt_ethane):
- assert alt_ethane.name == ''
+ assert alt_ethane.name == ""
@pytest.mark.xfail
def test_serialization_cycle(self, named_ethane):
@@ -136,24 +141,23 @@ def test_serialization_cycle(self, named_ethane):
assert serialized == reserialized
def test_to_sdf_string(self, named_ethane, ethane_sdf):
- with open(ethane_sdf, "r") as f:
+ with open(ethane_sdf) as f:
expected = f.read()
assert named_ethane.to_sdf() == expected
@pytest.mark.xfail
def test_from_sdf_string(self, named_ethane, ethane_sdf):
- with open(ethane_sdf, "r") as f:
+ with open(ethane_sdf) as f:
sdf_str = f.read()
assert SmallMoleculeComponent.from_sdf_string(sdf_str) == named_ethane
@pytest.mark.xfail
- def test_from_sdf_file(self, named_ethane, ethane_sdf,
- tmpdir):
- with open(ethane_sdf, 'r') as f:
+ def test_from_sdf_file(self, named_ethane, ethane_sdf, tmpdir):
+ with open(ethane_sdf) as f:
sdf_str = f.read()
- with open(tmpdir / "temp.sdf", mode='w') as tmpf:
+ with open(tmpdir / "temp.sdf", mode="w") as tmpf:
tmpf.write(sdf_str)
assert SmallMoleculeComponent.from_sdf_file(tmpdir / "temp.sdf") == named_ethane
@@ -163,7 +167,7 @@ def test_from_sdf_file_junk(self, toluene_mol2_path):
SmallMoleculeComponent.from_sdf_file(toluene_mol2_path)
def test_from_sdf_string_multiple_molecules(self, multi_molecule_sdf):
- data = open(multi_molecule_sdf, 'r').read()
+ data = open(multi_molecule_sdf).read()
with pytest.raises(RuntimeError, match="contains more than 1"):
SmallMoleculeComponent.from_sdf_string(data)
@@ -184,17 +188,20 @@ def test_serialization_cycle_smiles(self, named_ethane):
assert named_ethane is not copy
assert named_ethane.smiles == copy.smiles
- @pytest.mark.parametrize('replace', (
- ['name'],
- ['mol'],
- ['name', 'mol'],
- ))
+ @pytest.mark.parametrize(
+ "replace",
+ (
+ ["name"],
+ ["mol"],
+ ["name", "mol"],
+ ),
+ )
def test_copy_with_replacements(self, named_ethane, replace):
replacements = {}
- if 'name' in replace:
- replacements['name'] = "foo"
+ if "name" in replace:
+ replacements["name"] = "foo"
- if 'mol' in replace:
+ if "mol" in replace:
# it is a little weird to use copy_with_replacements to replace
# the whole molecule (possibly keeping the same name), but it
# should work if someone does! (could more easily imagine only
@@ -203,16 +210,16 @@ def test_copy_with_replacements(self, named_ethane, replace):
Chem.AllChem.Compute2DCoords(rdmol)
mol = SmallMoleculeComponent.from_rdkit(rdmol)
dct = mol._to_dict()
- for item in ['atoms', 'bonds', 'conformer']:
+ for item in ["atoms", "bonds", "conformer"]:
replacements[item] = dct[item]
new = named_ethane.copy_with_replacements(**replacements)
- if 'name' in replace:
+ if "name" in replace:
assert new.name == "foo"
else:
assert new.name == "ethane"
- if 'mol' in replace:
+ if "mol" in replace:
assert new.smiles == "CO"
else:
assert new.smiles == "CC"
@@ -228,15 +235,15 @@ def test_to_off(self, ethane):
def test_to_off_name(self, named_ethane):
off_ethane = named_ethane.to_openff()
- assert off_ethane.name == 'ethane'
+ assert off_ethane.name == "ethane"
@pytest.mark.skipif(not HAS_OFFTK, reason="no openff tookit available")
class TestSmallMoleculeComponentPartialCharges:
- @pytest.fixture(scope='function')
+ @pytest.fixture(scope="function")
def charged_off_ethane(self, ethane):
off_ethane = ethane.to_openff()
- off_ethane.assign_partial_charges(partial_charge_method='am1bcc')
+ off_ethane.assign_partial_charges(partial_charge_method="am1bcc")
return off_ethane
def test_partial_charges_warning(self, charged_off_ethane):
@@ -258,7 +265,7 @@ def test_partial_charges_not_formal_error(self, charged_off_ethane):
def test_partial_charges_too_few_atoms(self):
mol = Chem.AddHs(Chem.MolFromSmiles("CC"))
Chem.AllChem.Compute2DCoords(mol)
- mol.SetProp('atom.dprop.PartialCharge', '1')
+ mol.SetProp("atom.dprop.PartialCharge", "1")
with pytest.raises(ValueError, match="Incorrect number of"):
SmallMoleculeComponent.from_rdkit(mol)
@@ -271,7 +278,7 @@ def test_partial_charges_applied_to_atoms(self):
mol = Chem.AddHs(Chem.MolFromSmiles("C"))
Chem.AllChem.Compute2DCoords(mol)
# add some fake charges at the molecule level
- mol.SetProp('atom.dprop.PartialCharge', '-1 0.25 0.25 0.25 0.25')
+ mol.SetProp("atom.dprop.PartialCharge", "-1 0.25 0.25 0.25 0.25")
matchmsg = "Partial charges have been provided"
with pytest.warns(UserWarning, match=matchmsg):
ofe = SmallMoleculeComponent.from_rdkit(mol)
@@ -292,22 +299,24 @@ def test_inconsistent_charges(self, charged_off_ethane):
mol = Chem.AddHs(Chem.MolFromSmiles("C"))
Chem.AllChem.Compute2DCoords(mol)
# add some fake charges at the molecule level
- mol.SetProp('atom.dprop.PartialCharge', '-1 0.25 0.25 0.25 0.25')
+ mol.SetProp("atom.dprop.PartialCharge", "-1 0.25 0.25 0.25 0.25")
# set different charges to the atoms
for atom in mol.GetAtoms():
atom.SetDoubleProp("PartialCharge", 0)
# make sure the correct error is raised
- msg = ("non-equivalent partial charges between "
- "atom and molecule properties")
+ msg = "non-equivalent partial charges between " "atom and molecule properties"
with pytest.raises(ValueError, match=msg):
SmallMoleculeComponent.from_rdkit(mol)
-
-@pytest.mark.parametrize('mol, charge', [
- ('CC', 0), ('CC[O-]', -1),
-])
+@pytest.mark.parametrize(
+ "mol, charge",
+ [
+ ("CC", 0),
+ ("CC[O-]", -1),
+ ],
+)
def test_total_charge_neutral(mol, charge):
mol = Chem.MolFromSmiles(mol)
AllChem.Compute2DCoords(mol)
@@ -330,6 +339,36 @@ def test_to_dict(self, phenol):
assert isinstance(d, dict)
+ def test_to_dict_hybridization(self, phenol):
+ """
+ Make sure dict round trip saves the hybridization
+
+ """
+ phenol_dict = phenol.to_dict()
+ TOKENIZABLE_REGISTRY.clear()
+ new_phenol = SmallMoleculeComponent.from_dict(phenol_dict)
+ for atom in new_phenol.to_rdkit().GetAtoms():
+ if atom.GetAtomicNum() == 6:
+ assert atom.GetHybridization() == Chem.rdchem.HybridizationType.SP2
+
+ def test_from_dict_missing_hybridization(self, phenol):
+ """
+ For backwards compatibility make sure we can create an SMC with missing hybridization info.
+ """
+ phenol_dict = phenol.to_dict()
+ new_atoms = []
+ for atom in phenol_dict["atoms"]:
+ # remove the hybridization atomic info which should be at index 7
+ new_atoms.append(tuple([atom_info for i, atom_info in enumerate(atom) if i != 7]))
+ phenol_dict["atoms"] = new_atoms
+ with pytest.warns(match="The atom hybridization data was not found and has been set to unspecified."):
+ new_phenol = SmallMoleculeComponent.from_dict(phenol_dict)
+ # they should be different objects due to the missing hybridization info
+ assert new_phenol != phenol
+ # make sure the rdkit objects are different
+ for atom_hybrid, atom_no_hybrid in zip(phenol.to_rdkit().GetAtoms(), new_phenol.to_rdkit().GetAtoms()):
+ assert atom_hybrid.GetHybridization() != atom_no_hybrid.GetHybridization()
+
@pytest.mark.skipif(not HAS_OFFTK, reason="no openff toolkit available")
def test_deserialize_roundtrip(self, toluene, phenol):
roundtrip = SmallMoleculeComponent.from_dict(phenol.to_dict())
@@ -347,11 +386,11 @@ def test_deserialize_roundtrip(self, toluene, phenol):
@pytest.mark.xfail
def test_bounce_off_file(self, toluene, tmpdir):
- fname = str(tmpdir / 'mol.json')
+ fname = str(tmpdir / "mol.json")
- with open(fname, 'w') as f:
+ with open(fname, "w") as f:
f.write(toluene.to_json())
- with open(fname, 'r') as f:
+ with open(fname) as f:
d = json.load(f)
assert isinstance(d, dict)
@@ -373,29 +412,29 @@ def test_to_openff_after_serialisation(self, toluene):
assert off1 == off2
-@pytest.mark.parametrize('target', ['atom', 'bond', 'conformer', 'mol'])
-@pytest.mark.parametrize('dtype', ['int', 'bool', 'str', 'float'])
+@pytest.mark.parametrize("target", ["atom", "bond", "conformer", "mol"])
+@pytest.mark.parametrize("dtype", ["int", "bool", "str", "float"])
def test_prop_preservation(ethane, target, dtype):
# issue 145 make sure props are propagated
mol = Chem.MolFromSmiles("CC")
Chem.AllChem.Compute2DCoords(mol)
- if target == 'atom':
+ if target == "atom":
obj = mol.GetAtomWithIdx(0)
- elif target == 'bond':
+ elif target == "bond":
obj = mol.GetBondWithIdx(0)
- elif target == 'conformer':
+ elif target == "conformer":
obj = mol.GetConformer()
else:
obj = mol
- if dtype == 'int':
- obj.SetIntProp('foo', 1234)
- elif dtype == 'bool':
- obj.SetBoolProp('foo', False)
- elif dtype == 'str':
- obj.SetProp('foo', 'bar')
- elif dtype == 'float':
- obj.SetDoubleProp('foo', 1.234)
+ if dtype == "int":
+ obj.SetIntProp("foo", 1234)
+ elif dtype == "bool":
+ obj.SetBoolProp("foo", False)
+ elif dtype == "str":
+ obj.SetProp("foo", "bar")
+ elif dtype == "float":
+ obj.SetDoubleProp("foo", 1.234)
else:
pytest.fail()
@@ -403,29 +442,29 @@ def test_prop_preservation(ethane, target, dtype):
d = SmallMoleculeComponent(rdkit=mol).to_dict()
e2 = SmallMoleculeComponent.from_dict(d).to_rdkit()
- if target == 'atom':
+ if target == "atom":
obj = e2.GetAtomWithIdx(0)
- elif target == 'bond':
+ elif target == "bond":
obj = e2.GetBondWithIdx(0)
- elif target == 'conformer':
+ elif target == "conformer":
obj = e2.GetConformer()
else:
obj = e2
- if dtype == 'int':
- assert obj.GetIntProp('foo') == 1234
- elif dtype == 'bool':
- assert obj.GetBoolProp('foo') is False
- elif dtype == 'str':
- assert obj.GetProp('foo') == 'bar'
+ if dtype == "int":
+ assert obj.GetIntProp("foo") == 1234
+ elif dtype == "bool":
+ assert obj.GetBoolProp("foo") is False
+ elif dtype == "str":
+ assert obj.GetProp("foo") == "bar"
else:
- assert obj.GetDoubleProp('foo') == pytest.approx(1.234)
+ assert obj.GetDoubleProp("foo") == pytest.approx(1.234)
def test_missing_H_warning():
- m = Chem.MolFromSmiles('CC')
+ m = Chem.MolFromSmiles("CC")
Chem.AllChem.Compute2DCoords(m)
- with pytest.warns(UserWarning, match='removeHs=False'):
+ with pytest.warns(UserWarning, match="removeHs=False"):
_ = SmallMoleculeComponent(rdkit=m)
diff --git a/gufe/tests/test_solvents.py b/gufe/tests/test_solvents.py
index 88b2a87b..e8530777 100644
--- a/gufe/tests/test_solvents.py
+++ b/gufe/tests/test_solvents.py
@@ -1,7 +1,7 @@
import pytest
+from openff.units import unit
from gufe import SolventComponent
-from openff.units import unit
from .test_tokenization import GufeTokenizableTestsMixin
@@ -9,70 +9,77 @@
def test_defaults():
s = SolventComponent()
- assert s.smiles == 'O'
+ assert s.smiles == "O"
assert s.positive_ion == "Na+"
assert s.negative_ion == "Cl-"
assert s.ion_concentration == 0.15 * unit.molar
-@pytest.mark.parametrize('pos, neg', [
- # test: charge dropping, case sensitivity
- ('Na', 'Cl'), ('Na+', 'Cl-'), ('na', 'cl'),
-])
+@pytest.mark.parametrize(
+ "pos, neg",
+ [
+ # test: charge dropping, case sensitivity
+ ("Na", "Cl"),
+ ("Na+", "Cl-"),
+ ("na", "cl"),
+ ],
+)
def test_hash(pos, neg):
- s1 = SolventComponent(positive_ion='Na', negative_ion='Cl')
+ s1 = SolventComponent(positive_ion="Na", negative_ion="Cl")
s2 = SolventComponent(positive_ion=pos, negative_ion=neg)
assert s1 == s2
assert hash(s1) == hash(s2)
- assert s2.positive_ion == 'Na+'
- assert s2.negative_ion == 'Cl-'
+ assert s2.positive_ion == "Na+"
+ assert s2.negative_ion == "Cl-"
def test_neq():
- s1 = SolventComponent(positive_ion='Na', negative_ion='Cl')
- s2 = SolventComponent(positive_ion='K', negative_ion='Cl')
+ s1 = SolventComponent(positive_ion="Na", negative_ion="Cl")
+ s2 = SolventComponent(positive_ion="K", negative_ion="Cl")
assert s1 != s2
-@pytest.mark.parametrize('conc', [0.0 * unit.molar, 1.75 * unit.molar])
+@pytest.mark.parametrize("conc", [0.0 * unit.molar, 1.75 * unit.molar])
def test_from_dict(conc):
- s1 = SolventComponent(positive_ion='Na', negative_ion='Cl',
- ion_concentration=conc,
- neutralize=False)
+ s1 = SolventComponent(positive_ion="Na", negative_ion="Cl", ion_concentration=conc, neutralize=False)
assert SolventComponent.from_dict(s1.to_dict()) == s1
def test_conc():
- s = SolventComponent(positive_ion='Na', negative_ion='Cl',
- ion_concentration=1.75 * unit.molar)
+ s = SolventComponent(positive_ion="Na", negative_ion="Cl", ion_concentration=1.75 * unit.molar)
- assert s.ion_concentration == unit.Quantity('1.75 M')
+ assert s.ion_concentration == unit.Quantity("1.75 M")
-@pytest.mark.parametrize('conc,',
- [1.22, # no units, 1.22 what?
- 1.5 * unit.kg, # probably a tad much salt
- -0.1 * unit.molar]) # negative conc
+@pytest.mark.parametrize(
+ "conc,",
+ [
+ 1.22, # no units, 1.22 what?
+ 1.5 * unit.kg, # probably a tad much salt
+ -0.1 * unit.molar,
+ ],
+) # negative conc
def test_bad_conc(conc):
with pytest.raises(ValueError):
- _ = SolventComponent(positive_ion='Na', negative_ion='Cl',
- ion_concentration=conc)
+ _ = SolventComponent(positive_ion="Na", negative_ion="Cl", ion_concentration=conc)
def test_solvent_charge():
- s = SolventComponent(positive_ion='Na', negative_ion='Cl',
- ion_concentration=1.75 * unit.molar)
+ s = SolventComponent(positive_ion="Na", negative_ion="Cl", ion_concentration=1.75 * unit.molar)
assert s.total_charge is None
-@pytest.mark.parametrize('pos, neg,', [
- ('Na', 'C'),
- ('F', 'I'),
-])
+@pytest.mark.parametrize(
+ "pos, neg,",
+ [
+ ("Na", "C"),
+ ("F", "I"),
+ ],
+)
def test_bad_inputs(pos, neg):
with pytest.raises(ValueError):
_ = SolventComponent(positive_ion=pos, negative_ion=neg)
@@ -85,4 +92,4 @@ class TestSolventComponent(GufeTokenizableTestsMixin):
@pytest.fixture
def instance(self):
- return SolventComponent(positive_ion='Na', negative_ion='Cl')
+ return SolventComponent(positive_ion="Na", negative_ion="Cl")
diff --git a/gufe/tests/test_tokenization.py b/gufe/tests/test_tokenization.py
index f7b0744d..fc2b49ce 100644
--- a/gufe/tests/test_tokenization.py
+++ b/gufe/tests/test_tokenization.py
@@ -1,17 +1,26 @@
-import pytest
import abc
import datetime
-import logging
import io
-from unittest import mock
import json
+import logging
from typing import Optional
+from unittest import mock
+
+import pytest
from gufe.tokenization import (
- GufeTokenizable, GufeKey, tokenize, TOKENIZABLE_REGISTRY,
- import_qualname, get_class, TOKENIZABLE_CLASS_REGISTRY, JSON_HANDLER,
- get_all_gufe_objs, gufe_to_digraph, gufe_objects_from_shallow_dict,
+ JSON_HANDLER,
+ TOKENIZABLE_CLASS_REGISTRY,
+ TOKENIZABLE_REGISTRY,
+ GufeKey,
+ GufeTokenizable,
KeyedChain,
+ get_all_gufe_objs,
+ get_class,
+ gufe_objects_from_shallow_dict,
+ gufe_to_digraph,
+ import_qualname,
+ tokenize,
)
@@ -65,7 +74,7 @@ def __init__(self, obj, lst, dct):
self.dct = dct
def _to_dict(self):
- return {'obj': self.obj, 'lst': self.lst, 'dct': self.dct}
+ return {"obj": self.obj, "lst": self.lst, "dct": self.dct}
@classmethod
def _from_dict(cls, dct):
@@ -87,9 +96,7 @@ class GufeTokenizableTestsMixin(abc.ABC):
@pytest.fixture
def instance(self):
- """Define instance to test with here.
-
- """
+ """Define instance to test with here."""
...
def test_to_dict_roundtrip(self, instance):
@@ -102,7 +109,7 @@ def test_to_dict_roundtrip(self, instance):
# not generally true that the dict forms are equal, e.g. if they
# include `np.nan`s
- #assert ser == reser
+ # assert ser == reser
@pytest.mark.skip
def test_to_dict_roundtrip_clear_registry(self, instance):
@@ -125,7 +132,7 @@ def test_to_keyed_dict_roundtrip(self, instance):
# not generally true that the dict forms are equal, e.g. if they
# include `np.nan`s
- #assert ser == reser
+ # assert ser == reser
def test_to_shallow_dict_roundtrip(self, instance):
ser = instance.to_shallow_dict()
@@ -137,7 +144,7 @@ def test_to_shallow_dict_roundtrip(self, instance):
# not generally true that the dict forms are equal, e.g. if they
# include `np.nan`s
- #assert ser == reser
+ # assert ser == reser
def test_key_stable(self, instance):
"""Check that generating the instance from a dict representation yields
@@ -164,9 +171,7 @@ class TestGufeTokenizable(GufeTokenizableTestsMixin):
@pytest.fixture
def instance(self):
- """Define instance to test with here.
-
- """
+ """Define instance to test with here."""
return self.cont
def setup_method(self):
@@ -176,49 +181,55 @@ def setup_method(self):
self.cont = Container(bar, [leaf, 0], {"leaf": leaf, "a": "b"})
def leaf_dict(a):
- return {'__module__': __name__, '__qualname__': "Leaf", "a": a,
- "b": 2, ':version:': 1}
+ return {
+ "__module__": __name__,
+ "__qualname__": "Leaf",
+ "a": a,
+ "b": 2,
+ ":version:": 1,
+ }
self.expected_deep = {
- '__qualname__': "Container",
- '__module__': __name__,
- 'obj': leaf_dict(leaf_dict("foo")),
- 'lst': [leaf_dict("foo"), 0],
- 'dct': {"leaf": leaf_dict("foo"), "a": "b"},
- ':version:': 1,
+ "__qualname__": "Container",
+ "__module__": __name__,
+ "obj": leaf_dict(leaf_dict("foo")),
+ "lst": [leaf_dict("foo"), 0],
+ "dct": {"leaf": leaf_dict("foo"), "a": "b"},
+ ":version:": 1,
}
self.expected_shallow = {
- '__qualname__': "Container",
- '__module__': __name__,
- 'obj': bar,
- 'lst': [leaf, 0],
- 'dct': {'leaf': leaf, 'a': 'b'},
- ':version:': 1,
+ "__qualname__": "Container",
+ "__module__": __name__,
+ "obj": bar,
+ "lst": [leaf, 0],
+ "dct": {"leaf": leaf, "a": "b"},
+ ":version:": 1,
}
self.expected_keyed = {
- '__qualname__': "Container",
- '__module__': __name__,
- 'obj': {":gufe-key:": bar.key},
- 'lst': [{":gufe-key:": leaf.key}, 0],
- 'dct': {'leaf': {":gufe-key:": leaf.key}, 'a': 'b'},
- ':version:': 1,
+ "__qualname__": "Container",
+ "__module__": __name__,
+ "obj": {":gufe-key:": bar.key},
+ "lst": [{":gufe-key:": leaf.key}, 0],
+ "dct": {"leaf": {":gufe-key:": leaf.key}, "a": "b"},
+ ":version:": 1,
}
self.expected_keyed_chain = [
- (str(leaf.key),
- leaf_dict("foo")),
- (str(bar.key),
- leaf_dict({':gufe-key:': str(leaf.key)})),
- (str(self.cont.key),
- {':version:': 1,
- '__module__': __name__,
- '__qualname__': 'Container',
- 'dct': {'a': 'b',
- 'leaf': {':gufe-key:': str(leaf.key)}},
- 'lst': [{':gufe-key:': str(leaf.key)}, 0],
- 'obj': {':gufe-key:': str(bar.key)}})
+ (str(leaf.key), leaf_dict("foo")),
+ (str(bar.key), leaf_dict({":gufe-key:": str(leaf.key)})),
+ (
+ str(self.cont.key),
+ {
+ ":version:": 1,
+ "__module__": __name__,
+ "__qualname__": "Container",
+ "dct": {"a": "b", "leaf": {":gufe-key:": str(leaf.key)}},
+ "lst": [{":gufe-key:": str(leaf.key)}, 0],
+ "obj": {":gufe-key:": str(bar.key)},
+ },
+ ),
]
def test_set_key(self):
@@ -283,7 +294,11 @@ def test_to_json_file(self, tmpdir):
def test_from_json_file(self, tmpdir):
file_path = tmpdir / "container.json"
- json.dump(self.expected_keyed_chain, file_path.open(mode="w"), cls=JSON_HANDLER.encoder)
+ json.dump(
+ self.expected_keyed_chain,
+ file_path.open(mode="w"),
+ cls=JSON_HANDLER.encoder,
+ )
recreated = self.cls.from_json(file=file_path)
assert recreated == self.cont
@@ -299,7 +314,7 @@ def test_from_shallow_dict(self):
# here we keep the same objects in memory
assert recreated.obj.a is recreated.lst[0]
- assert recreated.obj.a is recreated.dct['leaf']
+ assert recreated.obj.a is recreated.dct["leaf"]
def test_notequal_different_type(self):
l1 = Leaf(4)
@@ -326,13 +341,11 @@ def test_copy_with_replacements_invalid(self):
with pytest.raises(TypeError, match="Invalid"):
_ = l1.copy_with_replacements(foo=10)
- @pytest.mark.parametrize('level', ["DEBUG", "INFO", "CRITICAL"])
+ @pytest.mark.parametrize("level", ["DEBUG", "INFO", "CRITICAL"])
def test_logging(self, level):
stream = io.StringIO()
handler = logging.StreamHandler(stream)
- fmt = logging.Formatter(
- "%(name)s - %(gufekey)s - %(levelname)s - %(message)s"
- )
+ fmt = logging.Formatter("%(name)s - %(gufekey)s - %(levelname)s - %(message)s")
name = "gufekey.gufe.tests.test_tokenization.Leaf"
logger = logging.getLogger(name)
logger.setLevel(getattr(logging, level))
@@ -342,7 +355,7 @@ def test_logging(self, level):
leaf = Leaf(10)
results = stream.getvalue()
- key = leaf.key.split('-')[-1]
+ key = leaf.key.split("-")[-1]
initial_log = f"{name} - UNKNOWN - INFO - no key defined!\n"
info_log = f"{name} - {key} - INFO - a=10\n"
@@ -370,11 +383,14 @@ class Inner:
pass
-@pytest.mark.parametrize('modname, qualname, expected', [
- (__name__, "Outer", Outer),
- (__name__, "Outer.Inner", Outer.Inner),
- ("gufe.tokenization", 'import_qualname', import_qualname),
-])
+@pytest.mark.parametrize(
+ "modname, qualname, expected",
+ [
+ (__name__, "Outer", Outer),
+ (__name__, "Outer.Inner", Outer.Inner),
+ ("gufe.tokenization", "import_qualname", import_qualname),
+ ],
+)
def test_import_qualname(modname, qualname, expected):
assert import_qualname(modname, qualname) is expected
@@ -382,9 +398,9 @@ def test_import_qualname(modname, qualname, expected):
def test_import_qualname_not_yet_imported():
# this is specifically to test that something we don't have imported in
# this module will import correctly
- msg_cls = import_qualname(modname="email.message",
- qualname="EmailMessage")
+ msg_cls = import_qualname(modname="email.message", qualname="EmailMessage")
from email.message import EmailMessage
+
assert msg_cls is EmailMessage
@@ -393,26 +409,33 @@ def test_import_qualname_remappings():
assert import_qualname("foo", "Bar.Baz", remappings) is Outer.Inner
-@pytest.mark.parametrize('modname, qualname', [
- (None, "Outer.Inner"),
- (__name__, None),
-])
+@pytest.mark.parametrize(
+ "modname, qualname",
+ [
+ (None, "Outer.Inner"),
+ (__name__, None),
+ ],
+)
def test_import_qualname_error_none(modname, qualname):
with pytest.raises(ValueError, match="cannot be None"):
import_qualname(modname, qualname)
-
-@pytest.mark.parametrize('cls_reg', [
- {},
- {(__name__, "Outer.Inner"): Outer.Inner},
-])
+@pytest.mark.parametrize(
+ "cls_reg",
+ [
+ {},
+ {(__name__, "Outer.Inner"): Outer.Inner},
+ ],
+)
def test_get_class(cls_reg):
with mock.patch.dict("gufe.tokenization.TOKENIZABLE_CLASS_REGISTRY", cls_reg):
assert get_class(__name__, "Outer.Inner") is Outer.Inner
+
def test_path_to_json():
import pathlib
+
p = pathlib.Path("foo/bar")
ser = json.dumps(p, cls=JSON_HANDLER.encoder)
deser = json.loads(ser, cls=JSON_HANDLER.decoder)
@@ -423,27 +446,25 @@ def test_path_to_json():
class TestGufeKey:
def test_to_dict(self):
- k = GufeKey('foo-bar')
+ k = GufeKey("foo-bar")
- assert k.to_dict() == {':gufe-key:': 'foo-bar'}
+ assert k.to_dict() == {":gufe-key:": "foo-bar"}
def test_prefix(self):
- k = GufeKey('foo-bar')
+ k = GufeKey("foo-bar")
- assert k.prefix == 'foo'
+ assert k.prefix == "foo"
def test_token(self):
- k = GufeKey('foo-bar')
+ k = GufeKey("foo-bar")
- assert k.token == 'bar'
+ assert k.token == "bar"
def test_gufe_to_digraph(solvated_complex):
graph = gufe_to_digraph(solvated_complex)
- connected_objects = gufe_objects_from_shallow_dict(
- solvated_complex.to_shallow_dict()
- )
+ connected_objects = gufe_objects_from_shallow_dict(solvated_complex.to_shallow_dict())
assert len(graph.nodes) == 4
assert len(graph.edges) == 3
@@ -472,9 +493,7 @@ def test_from_gufe(self, benzene_variants_star_map):
assert len(kc) == expected_len
original_keys = [obj.key for obj in contained_objects]
- original_keyed_dicts = [
- obj.to_keyed_dict() for obj in contained_objects
- ]
+ original_keyed_dicts = [obj.to_keyed_dict() for obj in contained_objects]
kc_gufe_keys = set(kc.gufe_keys())
kc_keyed_dicts = list(kc.keyed_dicts())
@@ -498,7 +517,7 @@ def test_get_item(self, benzene_variants_star_map):
def test_datetime_to_json():
- d = datetime.datetime.fromisoformat('2023-05-05T09:06:43.699068')
+ d = datetime.datetime.fromisoformat("2023-05-05T09:06:43.699068")
ser = json.dumps(d, cls=JSON_HANDLER.encoder)
diff --git a/gufe/tests/test_transformation.py b/gufe/tests/test_transformation.py
index 3ad29c59..445ff651 100644
--- a/gufe/tests/test_transformation.py
+++ b/gufe/tests/test_transformation.py
@@ -1,13 +1,14 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
-import pytest
import io
import pathlib
+import pytest
+
import gufe
-from gufe.transformations import Transformation, NonTransformation
from gufe.protocols.protocoldag import execute_DAG
+from gufe.transformations import NonTransformation, Transformation
from .test_protocol import DummyProtocol, DummyProtocolResult
from .test_tokenization import GufeTokenizableTestsMixin
@@ -25,8 +26,10 @@ def absolute_transformation(solvated_ligand, solvated_complex):
@pytest.fixture
def complex_equilibrium(solvated_complex):
- return NonTransformation(solvated_complex,
- protocol=DummyProtocol(settings=DummyProtocol.default_settings()))
+ return NonTransformation(
+ solvated_complex,
+ protocol=DummyProtocol(settings=DummyProtocol.default_settings()),
+ )
class TestTransformation(GufeTokenizableTestsMixin):
@@ -51,12 +54,12 @@ def test_protocol(self, absolute_transformation, tmpdir):
protocoldag = tnf.create()
with tmpdir.as_cwd():
- shared = pathlib.Path('shared')
+ shared = pathlib.Path("shared")
shared.mkdir(parents=True)
- scratch = pathlib.Path('scratch')
+ scratch = pathlib.Path("scratch")
scratch.mkdir(parents=True)
-
+
protocoldagresult = execute_DAG(protocoldag, shared_basedir=shared, scratch_basedir=scratch)
protocolresult = tnf.gather([protocoldagresult])
@@ -64,8 +67,8 @@ def test_protocol(self, absolute_transformation, tmpdir):
assert isinstance(protocolresult, DummyProtocolResult)
assert len(protocolresult.data) == 2
- assert 'logs' in protocolresult.data
- assert 'key_results' in protocolresult.data
+ assert "logs" in protocolresult.data
+ assert "key_results" in protocolresult.data
def test_protocol_extend(self, absolute_transformation, tmpdir):
tnf = absolute_transformation
@@ -73,12 +76,12 @@ def test_protocol_extend(self, absolute_transformation, tmpdir):
assert isinstance(tnf.protocol, DummyProtocol)
with tmpdir.as_cwd():
- shared = pathlib.Path('shared')
+ shared = pathlib.Path("shared")
shared.mkdir(parents=True)
- scratch = pathlib.Path('scratch')
+ scratch = pathlib.Path("scratch")
scratch.mkdir(parents=True)
-
+
protocoldag = tnf.create()
protocoldagresult = execute_DAG(protocoldag, shared_basedir=shared, scratch_basedir=scratch)
@@ -94,8 +97,9 @@ def test_protocol_extend(self, absolute_transformation, tmpdir):
def test_equality(self, absolute_transformation, solvated_ligand, solvated_complex):
opposite = Transformation(
- solvated_complex, solvated_ligand,
- protocol=DummyProtocol(settings=DummyProtocol.default_settings())
+ solvated_complex,
+ solvated_ligand,
+ protocol=DummyProtocol(settings=DummyProtocol.default_settings()),
)
assert absolute_transformation != opposite
@@ -124,16 +128,16 @@ def test_dump_load_roundtrip(self, absolute_transformation):
assert absolute_transformation == recreated
def test_deprecation_warning_on_dict_mapping(self, solvated_ligand, solvated_complex):
- lig = solvated_complex.components['ligand']
+ lig = solvated_complex.components["ligand"]
# this mapping makes no sense, but it'll trigger the dep warning we want
mapping = gufe.LigandAtomMapping(lig, lig, componentA_to_componentB={})
- with pytest.warns(DeprecationWarning,
- match="mapping input as a dict is deprecated"):
+ with pytest.warns(DeprecationWarning, match="mapping input as a dict is deprecated"):
Transformation(
- solvated_complex, solvated_ligand,
+ solvated_complex,
+ solvated_ligand,
protocol=DummyProtocol(settings=DummyProtocol.default_settings()),
- mapping={'ligand': mapping},
+ mapping={"ligand": mapping},
)
@@ -160,10 +164,10 @@ def test_protocol(self, complex_equilibrium, tmpdir):
protocoldag = ntnf.create()
with tmpdir.as_cwd():
- shared = pathlib.Path('shared')
+ shared = pathlib.Path("shared")
shared.mkdir(parents=True)
- scratch = pathlib.Path('scratch')
+ scratch = pathlib.Path("scratch")
scratch.mkdir(parents=True)
protocoldagresult = execute_DAG(protocoldag, shared_basedir=shared, scratch_basedir=scratch)
@@ -173,8 +177,8 @@ def test_protocol(self, complex_equilibrium, tmpdir):
assert isinstance(protocolresult, DummyProtocolResult)
assert len(protocolresult.data) == 2
- assert 'logs' in protocolresult.data
- assert 'key_results' in protocolresult.data
+ assert "logs" in protocolresult.data
+ assert "key_results" in protocolresult.data
def test_protocol_extend(self, complex_equilibrium, tmpdir):
ntnf = complex_equilibrium
@@ -182,10 +186,10 @@ def test_protocol_extend(self, complex_equilibrium, tmpdir):
assert isinstance(ntnf.protocol, DummyProtocol)
with tmpdir.as_cwd():
- shared = pathlib.Path('shared')
+ shared = pathlib.Path("shared")
shared.mkdir(parents=True)
- scratch = pathlib.Path('scratch')
+ scratch = pathlib.Path("scratch")
scratch.mkdir(parents=True)
protocoldag = ntnf.create()
@@ -203,17 +207,17 @@ def test_protocol_extend(self, complex_equilibrium, tmpdir):
def test_equality(self, complex_equilibrium, solvated_ligand, solvated_complex):
s = DummyProtocol.default_settings()
s.n_repeats = 4031
- different_protocol_settings = NonTransformation(
- solvated_complex, protocol=DummyProtocol(settings=s)
- )
+ different_protocol_settings = NonTransformation(solvated_complex, protocol=DummyProtocol(settings=s))
assert complex_equilibrium != different_protocol_settings
identical = NonTransformation(
- solvated_complex, protocol=DummyProtocol(settings=DummyProtocol.default_settings())
+ solvated_complex,
+ protocol=DummyProtocol(settings=DummyProtocol.default_settings()),
)
assert complex_equilibrium == identical
different_system = NonTransformation(
- solvated_ligand, protocol=DummyProtocol(settings=DummyProtocol.default_settings())
+ solvated_ligand,
+ protocol=DummyProtocol(settings=DummyProtocol.default_settings()),
)
assert complex_equilibrium != different_system
diff --git a/gufe/tests/test_utils.py b/gufe/tests/test_utils.py
index 8bd83749..1e2ec962 100644
--- a/gufe/tests/test_utils.py
+++ b/gufe/tests/test_utils.py
@@ -1,30 +1,28 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
-import pytest
-
import io
import pathlib
+import pytest
+
from gufe.utils import ensure_filelike
-@pytest.mark.parametrize('input_type', [
- "str", "path", "TextIO", "BytesIO", "StringIO"
-])
+@pytest.mark.parametrize("input_type", ["str", "path", "TextIO", "BytesIO", "StringIO"])
def test_ensure_filelike(input_type, tmp_path):
path = tmp_path / "foo.txt"
# we choose to use bytes for pathlib.Path just to mix things up;
# string filename or path can be either bytes or string, so we give one
# to each
- use_bytes = input_type in {'path', 'BytesIO'}
- filelike = input_type not in {'str', 'path'}
+ use_bytes = input_type in {"path", "BytesIO"}
+ filelike = input_type not in {"str", "path"}
dumper = {
- 'str': str(path),
- 'path': path,
- 'TextIO': open(path, mode='w'),
- 'BytesIO': open(path, mode='wb'),
- 'StringIO': io.StringIO(),
+ "str": str(path),
+ "path": path,
+ "TextIO": open(path, mode="w"),
+ "BytesIO": open(path, mode="wb"),
+ "StringIO": io.StringIO(),
}[input_type]
if filelike:
@@ -40,15 +38,15 @@ def test_ensure_filelike(input_type, tmp_path):
write_f.write(written)
write_f.flush()
- if input_type == 'StringIO':
+ if input_type == "StringIO":
dumper.seek(0)
loader = {
- 'str': str(path),
- 'path': path,
- 'TextIO': open(path, mode='r'),
- 'BytesIO': open(path, mode='rb'),
- 'StringIO': dumper,
+ "str": str(path),
+ "path": path,
+ "TextIO": open(path),
+ "BytesIO": open(path, mode="rb"),
+ "StringIO": dumper,
}[input_type]
with ensure_filelike(loader, mode=read_mode) as read_f:
@@ -64,13 +62,14 @@ def test_ensure_filelike(input_type, tmp_path):
write_f.close()
read_f.close()
+
@pytest.mark.parametrize("input_type", ["TextIO", "BytesIO", "StringIO"])
def test_ensure_filelike_force_close(input_type, tmp_path):
path = tmp_path / "foo.txt"
dumper = {
- 'TextIO': open(path, mode='w'),
- 'BytesIO': open(path, mode='wb'),
- 'StringIO': io.StringIO(),
+ "TextIO": open(path, mode="w"),
+ "BytesIO": open(path, mode="wb"),
+ "StringIO": io.StringIO(),
}[input_type]
written = b"foo" if input_type == "BytesIO" else "foo"
@@ -79,22 +78,23 @@ def test_ensure_filelike_force_close(input_type, tmp_path):
assert f.closed
+
@pytest.mark.parametrize("input_type", ["TextIO", "BytesIO", "StringIO"])
def test_ensure_filelike_mode_warning(input_type, tmp_path):
path = tmp_path / "foo.txt"
dumper = {
- 'TextIO': open(path, mode='w'),
- 'BytesIO': open(path, mode='wb'),
- 'StringIO': io.StringIO(),
+ "TextIO": open(path, mode="w"),
+ "BytesIO": open(path, mode="wb"),
+ "StringIO": io.StringIO(),
}[input_type]
- with pytest.warns(UserWarning,
- match="User-specified mode will be ignored"):
+ with pytest.warns(UserWarning, match="User-specified mode will be ignored"):
_ = ensure_filelike(dumper, mode="w")
dumper.close()
+
def test_ensure_filelike_default_mode():
path = "foo.txt"
loader = ensure_filelike(path)
- assert loader.mode == 'r'
+ assert loader.mode == "r"
diff --git a/gufe/tokenization.py b/gufe/tokenization.py
index e7c959db..9ace1bf0 100644
--- a/gufe/tokenization.py
+++ b/gufe/tokenization.py
@@ -9,13 +9,15 @@
import inspect
import json
import logging
-import networkx as nx
import re
import warnings
import weakref
+from collections.abc import Generator
from itertools import chain
from os import PathLike
-from typing import Any, Union, List, Tuple, Dict, Generator, TextIO, Optional
+from typing import Any, Dict, List, Optional, TextIO, Tuple, Union
+
+import networkx as nx
from typing_extensions import Self
from gufe.custom_codecs import (
@@ -115,11 +117,12 @@ class _GufeLoggerAdapter(logging.LoggerAdapter):
extra: :class:`.GufeTokenizable`
the instance this adapter is associated with
"""
+
def process(self, msg, kwargs):
- extra = kwargs.get('extra', {})
- if (extra_dict := getattr(self, '_extra_dict', None)) is None:
+ extra = kwargs.get("extra", {})
+ if (extra_dict := getattr(self, "_extra_dict", None)) is None:
try:
- gufekey = self.extra.key.split('-')[-1]
+ gufekey = self.extra.key.split("-")[-1]
except Exception:
# no matter what happened, we have a bad key
gufekey = "UNKNOWN"
@@ -127,15 +130,13 @@ def process(self, msg, kwargs):
else:
save_extra_dict = True
- extra_dict = {
- 'gufekey': gufekey
- }
+ extra_dict = {"gufekey": gufekey}
if save_extra_dict:
self._extra_dict = extra_dict
extra.update(extra_dict)
- kwargs['extra'] = extra
+ kwargs["extra"] = extra
return msg, kwargs
@@ -193,8 +194,9 @@ def old_key_removed(dct, old_key, should_warn):
if should_warn:
# TODO: this should be put elsewhere so that the warning can be more
# meaningful (somewhere that knows what class we're recreating)
- warnings.warn(f"Outdated serialization: '{old_key}', with value "
- f"'{dct[old_key]}' is no longer used in this object")
+ warnings.warn(
+ f"Outdated serialization: '{old_key}', with value " f"'{dct[old_key]}' is no longer used in this object"
+ )
del dct[old_key]
return dct
@@ -231,6 +233,7 @@ def _label_to_parts(label):
See :func:`.nested_key_moved` for a description of the label.
"""
+
def _intify_if_possible(part):
try:
part = int(part)
@@ -238,10 +241,8 @@ def _intify_if_possible(part):
pass
return part
- parts = [
- _intify_if_possible(p) for p in re.split('\.|\[|\]', label)
- if p != ""
- ]
+
+ parts = [_intify_if_possible(p) for p in re.split(r"\.|\[|\]", label) if p != ""]
return parts
@@ -326,6 +327,7 @@ class GufeTokenizable(abc.ABC, metaclass=_ABCGufeClassMeta):
This extra work in serializing is important for hashes that are stable
*across different Python sessions*.
"""
+
@classmethod
def _schema_version(cls) -> int:
return 1
@@ -346,9 +348,7 @@ def __hash__(self):
return hash(self.key)
def _gufe_tokenize(self):
- """Return a list of normalized inputs for `gufe.base.tokenize`.
-
- """
+ """Return a list of normalized inputs for `gufe.base.tokenize`."""
return tokenize(self)
# return normalize(self.to_keyed_dict(include_defaults=False))
@@ -405,7 +405,7 @@ def serialization_migration(cls, old_dict, version):
@property
def logger(self):
"""Return logger adapter for this instance"""
- if (adapter := getattr(self, '_logger', None)) is None:
+ if (adapter := getattr(self, "_logger", None)) is None:
cls = self.__class__
logname = "gufekey." + cls.__module__ + "." + cls.__qualname__
logger = logging.getLogger(logname)
@@ -415,7 +415,7 @@ def logger(self):
@property
def key(self):
- if not hasattr(self, '_key') or self._key is None:
+ if not hasattr(self, "_key") or self._key is None:
prefix = self.__class__.__qualname__
token = self._gufe_tokenize()
self._key = GufeKey(f"{prefix}-{token}")
@@ -433,7 +433,7 @@ def _set_key(self, key: str):
key : str
contents of the GufeKey for this object
"""
- if old_key := getattr(self, '_key', None):
+ if old_key := getattr(self, "_key", None):
TOKENIZABLE_REGISTRY.pop(old_key)
self._key = GufeKey(key)
@@ -448,7 +448,7 @@ def defaults(cls):
"""
defaults = cls._defaults()
- defaults[':version:'] = cls._schema_version()
+ defaults[":version:"] = cls._schema_version()
return defaults
@classmethod
@@ -461,7 +461,8 @@ def _defaults(cls):
sig = inspect.signature(cls.__init__)
defaults = {
- param.name: param.default for param in sig.parameters.values()
+ param.name: param.default
+ for param in sig.parameters.values()
if param.default is not inspect.Parameter.empty
}
@@ -617,13 +618,12 @@ def copy_with_replacements(self, **replacements):
"""
dct = self._to_dict()
if invalid := set(replacements) - set(dct):
- raise TypeError(f"Invalid replacement keys: {invalid}. "
- f"Allowed keys are: {set(dct)}")
+ raise TypeError(f"Invalid replacement keys: {invalid}. " f"Allowed keys are: {set(dct)}")
dct.update(replacements)
return self._from_dict(dct)
- def to_keyed_chain(self) -> List[Tuple[str, Dict]]:
+ def to_keyed_chain(self) -> list[tuple[str, dict]]:
"""
Generate a keyed chain representation of the object.
@@ -634,7 +634,7 @@ def to_keyed_chain(self) -> List[Tuple[str, Dict]]:
return KeyedChain.gufe_to_keyed_chain_rep(self)
@classmethod
- def from_keyed_chain(cls, keyed_chain: List[Tuple[str, Dict]]):
+ def from_keyed_chain(cls, keyed_chain: list[tuple[str, dict]]):
"""
Generate an instance from keyed chain representation.
@@ -674,6 +674,7 @@ def to_json(self, file: Optional[PathLike | TextIO] = None) -> None | str:
return json.dumps(self.to_keyed_chain(), cls=JSON_HANDLER.encoder)
from gufe.utils import ensure_filelike
+
with ensure_filelike(file, mode="w") as out:
json.dump(self.to_keyed_chain(), out, cls=JSON_HANDLER.encoder)
@@ -708,6 +709,7 @@ def from_json(cls, file: Optional[PathLike | TextIO] = None, content: Optional[s
return cls.from_keyed_chain(keyed_chain=keyed_chain)
from gufe.utils import ensure_filelike
+
with ensure_filelike(file, mode="r") as f:
keyed_chain = json.load(f, cls=JSON_HANDLER.decoder)
@@ -715,26 +717,24 @@ def from_json(cls, file: Optional[PathLike | TextIO] = None, content: Optional[s
class GufeKey(str):
- def __repr__(self): # pragma: no cover
+ def __repr__(self): # pragma: no cover
        return f"<GufeKey('{str(self)}')>"
def to_dict(self):
- return {':gufe-key:': str(self)}
+ return {":gufe-key:": str(self)}
@property
def prefix(self) -> str:
"""Commonly indicates a classname"""
- return self.split('-')[0]
+ return self.split("-")[0]
@property
def token(self) -> str:
"""Unique hash of this key, typically a md5 value"""
- return self.split('-')[1]
+ return self.split("-")[1]
-def gufe_objects_from_shallow_dict(
- obj: Union[List, Dict, GufeTokenizable]
-) -> List[GufeTokenizable]:
+def gufe_objects_from_shallow_dict(obj: Union[list, dict, GufeTokenizable]) -> list[GufeTokenizable]:
"""Find GufeTokenizables within a shallow dict.
This function recursively looks through the list/dict structures encoding
@@ -759,16 +759,10 @@ def gufe_objects_from_shallow_dict(
return [obj]
elif isinstance(obj, list):
- return list(
- chain.from_iterable([gufe_objects_from_shallow_dict(item) for item in obj])
- )
+ return list(chain.from_iterable([gufe_objects_from_shallow_dict(item) for item in obj]))
elif isinstance(obj, dict):
- return list(
- chain.from_iterable(
- [gufe_objects_from_shallow_dict(item) for item in obj.values()]
- )
- )
+ return list(chain.from_iterable([gufe_objects_from_shallow_dict(item) for item in obj.values()]))
return []
@@ -811,7 +805,7 @@ def add_edges(o):
return graph
-class KeyedChain(object):
+class KeyedChain:
"""Keyed chain representation encoder of a GufeTokenizable.
The keyed chain representation of a GufeTokenizable provides a
@@ -859,25 +853,25 @@ def from_gufe(cls, gufe_object: GufeTokenizable) -> Self:
def to_gufe(self) -> GufeTokenizable:
"""Initialize a GufeTokenizable."""
- gts: Dict[str, GufeTokenizable] = {}
+ gts: dict[str, GufeTokenizable] = {}
for gufe_key, keyed_dict in self:
gt = key_decode_dependencies(keyed_dict, registry=gts)
gts[gufe_key] = gt
return gt
@classmethod
- def from_keyed_chain_rep(cls, keyed_chain: List[Tuple[str, Dict]]) -> Self:
+ def from_keyed_chain_rep(cls, keyed_chain: list[tuple[str, dict]]) -> Self:
"""Initialize a KeyedChain from a keyed chain representation."""
return cls(keyed_chain)
- def to_keyed_chain_rep(self) -> List[Tuple[str, Dict]]:
+ def to_keyed_chain_rep(self) -> list[tuple[str, dict]]:
"""Return the keyed chain representation of this object."""
return list(self)
@staticmethod
def gufe_to_keyed_chain_rep(
gufe_object: GufeTokenizable,
- ) -> List[Tuple[str, Dict]]:
+ ) -> list[tuple[str, dict]]:
"""Create the keyed chain representation of a GufeTokenizable.
This represents the GufeTokenizable as a list of two-element tuples
@@ -897,8 +891,7 @@ def gufe_to_keyed_chain_rep(
"""
key_and_keyed_dicts = [
- (str(gt.key), gt.to_keyed_dict())
- for gt in nx.topological_sort(gufe_to_digraph(gufe_object))
+ (str(gt.key), gt.to_keyed_dict()) for gt in nx.topological_sort(gufe_to_digraph(gufe_object))
][::-1]
return key_and_keyed_dicts
@@ -907,7 +900,7 @@ def gufe_keys(self) -> Generator[str, None, None]:
for key, _ in self:
yield key
- def keyed_dicts(self) -> Generator[Dict, None, None]:
+ def keyed_dicts(self) -> Generator[dict, None, None]:
"""Create a generator that iterates over the keyed dicts in the KeyedChain."""
for _, _dict in self:
yield _dict
@@ -936,8 +929,10 @@ def __getitem__(self, index):
def module_qualname(obj):
- return {'__qualname__': obj.__class__.__qualname__,
- '__module__': obj.__class__.__module__}
+ return {
+ "__qualname__": obj.__class__.__qualname__,
+ "__module__": obj.__class__.__module__,
+ }
def is_gufe_obj(obj: Any):
@@ -945,8 +940,7 @@ def is_gufe_obj(obj: Any):
def is_gufe_dict(dct: Any):
- return (isinstance(dct, dict) and '__qualname__' in dct
- and '__module__' in dct)
+ return isinstance(dct, dict) and "__qualname__" in dct and "__module__" in dct
def is_gufe_key_dict(dct: Any):
@@ -956,14 +950,15 @@ def is_gufe_key_dict(dct: Any):
# conveniences to get a class from module/class name
def import_qualname(modname: str, qualname: str, remappings=REMAPPED_CLASSES):
if (qualname is None) or (modname is None):
- raise ValueError("`__qualname__` or `__module__` cannot be None; "
- f"unable to identify object {modname}.{qualname}")
+ raise ValueError(
+ "`__qualname__` or `__module__` cannot be None; " f"unable to identify object {modname}.{qualname}"
+ )
if (modname, qualname) in remappings:
modname, qualname = remappings[(modname, qualname)]
result = importlib.import_module(modname)
- for name in qualname.split('.'):
+ for name in qualname.split("."):
result = getattr(result, name)
return result
@@ -1000,18 +995,16 @@ def modify_dependencies(obj: Union[dict, list], modifier, is_mine, mode, top=Tru
If `True`, skip modifying `obj` itself; needed for recursive use to
avoid early stopping on `obj`.
"""
- if is_mine(obj) and not top and mode == 'encode':
+ if is_mine(obj) and not top and mode == "encode":
obj = modifier(obj)
if isinstance(obj, dict):
- obj = {key: modify_dependencies(value, modifier, is_mine, mode=mode, top=False)
- for key, value in obj.items()}
+ obj = {key: modify_dependencies(value, modifier, is_mine, mode=mode, top=False) for key, value in obj.items()}
elif isinstance(obj, list):
- obj = [modify_dependencies(item, modifier, is_mine, mode=mode, top=False)
- for item in obj]
+ obj = [modify_dependencies(item, modifier, is_mine, mode=mode, top=False) for item in obj]
- if is_mine(obj) and not top and mode == 'decode':
+ if is_mine(obj) and not top and mode == "decode":
obj = modifier(obj)
return obj
@@ -1021,18 +1014,12 @@ def modify_dependencies(obj: Union[dict, list], modifier, is_mine, mode, top=Tru
def to_dict(obj: GufeTokenizable) -> dict:
dct = obj._to_dict()
dct.update(module_qualname(obj))
- dct[':version:'] = obj._schema_version()
+ dct[":version:"] = obj._schema_version()
return dct
def dict_encode_dependencies(obj: GufeTokenizable) -> dict:
- return modify_dependencies(
- obj.to_shallow_dict(),
- to_dict,
- is_gufe_obj,
- mode='encode',
- top=True
- )
+ return modify_dependencies(obj.to_shallow_dict(), to_dict, is_gufe_obj, mode="encode", top=True)
def key_encode_dependencies(obj: GufeTokenizable) -> dict:
@@ -1040,8 +1027,8 @@ def key_encode_dependencies(obj: GufeTokenizable) -> dict:
obj.to_shallow_dict(),
lambda obj: obj.key.to_dict(),
is_gufe_obj,
- mode='encode',
- top=True
+ mode="encode",
+ top=True,
)
@@ -1064,9 +1051,9 @@ def from_dict(dct) -> GufeTokenizable:
def _from_dict(dct: dict) -> GufeTokenizable:
- module = dct.pop('__module__')
- qualname = dct.pop('__qualname__')
- version = dct.pop(':version:', 1)
+ module = dct.pop("__module__")
+ qualname = dct.pop("__qualname__")
+ version = dct.pop(":version:", 1)
cls = get_class(module, qualname)
dct = cls.serialization_migration(dct, version)
@@ -1074,23 +1061,18 @@ def _from_dict(dct: dict) -> GufeTokenizable:
def dict_decode_dependencies(dct: dict) -> GufeTokenizable:
- return from_dict(
- modify_dependencies(dct, from_dict, is_gufe_dict, mode='decode', top=True)
- )
+ return from_dict(modify_dependencies(dct, from_dict, is_gufe_dict, mode="decode", top=True))
-def key_decode_dependencies(
- dct: dict,
- registry=TOKENIZABLE_REGISTRY
-) -> GufeTokenizable:
+def key_decode_dependencies(dct: dict, registry=TOKENIZABLE_REGISTRY) -> GufeTokenizable:
# this version requires that all dependent objects are already registered
# responsibility of the storage system that uses this to do so
dct = modify_dependencies(
dct,
lambda d: registry[GufeKey(d[":gufe-key:"])],
is_gufe_key_dict,
- mode='decode',
- top=True
+ mode="decode",
+ top=True,
)
return from_dict(dct)
@@ -1111,12 +1093,12 @@ def get_all_gufe_objs(obj):
all contained GufeTokenizables
"""
results = {obj}
+
def modifier(o):
results.add(o)
return o.to_shallow_dict()
- _ = modify_dependencies(obj.to_shallow_dict(), modifier, is_gufe_obj,
- mode='encode')
+ _ = modify_dependencies(obj.to_shallow_dict(), modifier, is_gufe_obj, mode="encode")
return results
@@ -1136,7 +1118,10 @@ def tokenize(obj: GufeTokenizable) -> str:
"""
# hasher = hashlib.md5(str(normalize(obj)).encode(), usedforsecurity=False)
- dumped = json.dumps(obj.to_keyed_dict(include_defaults=False),
- sort_keys=True, cls=JSON_HANDLER.encoder)
+ dumped = json.dumps(
+ obj.to_keyed_dict(include_defaults=False),
+ sort_keys=True,
+ cls=JSON_HANDLER.encoder,
+ )
hasher = hashlib.md5(dumped.encode(), usedforsecurity=False)
return hasher.hexdigest()
diff --git a/gufe/transformations/__init__.py b/gufe/transformations/__init__.py
index fe016384..2bb0e7b5 100644
--- a/gufe/transformations/__init__.py
+++ b/gufe/transformations/__init__.py
@@ -1,2 +1,3 @@
"""A chemical system and protocol combined form a Transformation"""
-from .transformation import Transformation, NonTransformation
+
+from .transformation import NonTransformation, Transformation
diff --git a/gufe/transformations/transformation.py b/gufe/transformations/transformation.py
index e018fdc5..e8df7a3b 100644
--- a/gufe/transformations/transformation.py
+++ b/gufe/transformations/transformation.py
@@ -2,27 +2,34 @@
# For details, see https://github.com/OpenFreeEnergy/gufe
import abc
-from typing import Optional, Iterable, Union
import json
import warnings
-
-from ..tokenization import GufeTokenizable, JSON_HANDLER
-from ..utils import ensure_filelike
+from collections.abc import Iterable
+from typing import Optional, Union
from ..chemicalsystem import ChemicalSystem
-from ..protocols import Protocol, ProtocolDAG, ProtocolResult, ProtocolDAGResult
from ..mapping import ComponentMapping
+from ..protocols import Protocol, ProtocolDAG, ProtocolDAGResult, ProtocolResult
+from ..tokenization import JSON_HANDLER, GufeTokenizable
+from ..utils import ensure_filelike
class TransformationBase(GufeTokenizable):
- """Transformation base class.
-
- """
def __init__(
self,
protocol: Protocol,
name: Optional[str] = None,
):
+ """Transformation base class.
+
+ Parameters
+ ----------
+ protocol : Protocol
+ The sampling method to use for the transformation.
+ name : str, optional
+ A human-readable name for this transformation.
+
+ """
self._protocol = protocol
self._name = name
@@ -32,11 +39,10 @@ def _defaults(cls):
@property
def name(self) -> Optional[str]:
- """
- Optional identifier for the transformation; used as part of its hash.
+ """Optional identifier for the transformation; used as part of its hash.
Set this to a unique value if adding multiple, otherwise identical
- transformations to the same :class:`AlchemicalNetwork` to avoid
+ transformations to the same :class:`.AlchemicalNetwork` to avoid
deduplication.
"""
return self._name
@@ -45,6 +51,18 @@ def name(self) -> Optional[str]:
def _from_dict(cls, d: dict):
return cls(**d)
+ @property
+ @abc.abstractmethod
+ def stateA(self) -> ChemicalSystem:
+ """The starting :class:`.ChemicalSystem` for the transformation."""
+ raise NotImplementedError
+
+ @property
+ @abc.abstractmethod
+ def stateB(self) -> ChemicalSystem:
+ """The ending :class:`.ChemicalSystem` for the transformation."""
+ raise NotImplementedError
+
@abc.abstractmethod
def create(
self,
@@ -53,27 +71,25 @@ def create(
name: Optional[str] = None,
) -> ProtocolDAG:
"""
- Returns a ``ProtocolDAG`` executing this ``Transformation.protocol``.
+ Returns a :class:`.ProtocolDAG` executing this ``Transformation.protocol``.
"""
raise NotImplementedError
- def gather(
- self, protocol_dag_results: Iterable[ProtocolDAGResult]
- ) -> ProtocolResult:
- """
- Gather multiple ``ProtocolDAGResult`` into a single ``ProtocolResult``.
+ def gather(self, protocol_dag_results: Iterable[ProtocolDAGResult]) -> ProtocolResult:
+        r"""Gather multiple :class:`.ProtocolDAGResult` \s into a single
+ :class:`.ProtocolResult`.
Parameters
----------
protocol_dag_results : Iterable[ProtocolDAGResult]
- The ``ProtocolDAGResult`` objects to assemble aggregate quantities
- from.
+ The :class:`.ProtocolDAGResult` objects to assemble aggregate
+ quantities from.
Returns
-------
ProtocolResult
- Aggregated results from many ``ProtocolDAGResult`` objects, all from
- a given ``Protocol``.
+ Aggregated results from many :class:`.ProtocolDAGResult` objects,
+ all from a given :class:`.Protocol`.
"""
return self.protocol.gather(protocol_dag_results=protocol_dag_results)
@@ -81,18 +97,17 @@ def gather(
def dump(self, file):
"""Dump this Transformation to a JSON file.
- Note that this is not space-efficient: for example, any
- ``Component`` which is used in both ``ChemicalSystem`` objects will be
- represented twice in the JSON output.
+ Note that this is not space-efficient: for example, any ``Component``
+ which is used in both ``ChemicalSystem`` objects will be represented
+ twice in the JSON output.
Parameters
----------
file : Union[PathLike, FileLike]
- a pathlike of filelike to save this transformation to.
+ A pathlike or filelike to save this transformation to.
"""
- with ensure_filelike(file, mode='w') as f:
- json.dump(self.to_dict(), f, cls=JSON_HANDLER.encoder,
- sort_keys=True)
+ with ensure_filelike(file, mode="w") as f:
+ json.dump(self.to_dict(), f, cls=JSON_HANDLER.encoder, sort_keys=True)
@classmethod
def load(cls, file):
@@ -101,9 +116,9 @@ def load(cls, file):
Parameters
----------
file : Union[PathLike, FileLike]
- a pathlike or filelike to read this transformation from
+ A pathlike or filelike to read this transformation from.
"""
- with ensure_filelike(file, mode='r') as f:
+ with ensure_filelike(file, mode="r") as f:
dct = json.load(f, cls=JSON_HANDLER.decoder)
return cls.from_dict(dct)
@@ -124,25 +139,27 @@ def __init__(
mapping: Optional[Union[ComponentMapping, list[ComponentMapping], dict[str, ComponentMapping]]] = None,
name: Optional[str] = None,
):
- """Two chemical states with a method for estimating free energy difference
+ """Two chemical states with a method for estimating the free energy
+ difference between them.
- Connects two :class:`.ChemicalSystem` objects, with directionality,
- and relates this to a :class:`.Protocol` which will provide an estimate of
- the free energy difference of moving between these systems.
- Used as an edge of an :class:`.AlchemicalNetwork`.
+ Connects two :class:`.ChemicalSystem` objects, with directionality, and
+ relates these to a :class:`.Protocol` which will provide an estimate of
+ the free energy difference between these systems. Used as an edge of an
+ :class:`.AlchemicalNetwork`.
Parameters
----------
- stateA, stateB: ChemicalSystem
- The start (A) and end (B) states of the transformation
- protocol: Protocol
- The method used to estimate the free energy difference between states
- A and B
+ stateA, stateB : ChemicalSystem
+ The start (A) and end (B) states of the transformation.
+ protocol : Protocol
+ The method used to estimate the free energy difference between
+ states A and B.
mapping : Optional[Union[ComponentMapping, list[ComponentMapping]]]
- the details of any transformations between :class:`.Component` \s of
- the two states
+ The details of any transformations between :class:`.Component` \s
+ of the two states.
name : str, optional
- a human-readable tag for this transformation
+ A human-readable name for this transformation.
+
"""
if isinstance(mapping, dict):
warnings.warn(("mapping input as a dict is deprecated, "
@@ -157,9 +174,7 @@ def __init__(
self._name = name
def __repr__(self):
- attrs = ['name', 'stateA', 'stateB', 'protocol', 'mapping']
- content = ", ".join([f"{i}={getattr(self, i)}" for i in attrs])
- return f"{self.__class__.__name__}({content})"
+ return f"{self.__class__.__name__}(stateA={self.stateA}, " f"stateB={self.stateB}, protocol={self.protocol})"
@property
def stateA(self) -> ChemicalSystem:
@@ -173,7 +188,7 @@ def stateB(self) -> ChemicalSystem:
@property
def protocol(self) -> Protocol:
- """The protocol used to perform the transformation.
+ """The :class:`.Protocol` used to perform the transformation.
This protocol estimates the free energy differences between ``stateA``
and ``stateB`` :class:`.ChemicalSystem` objects. It includes all details
@@ -216,17 +231,6 @@ def create(
class NonTransformation(TransformationBase):
- """A non-alchemical edge of an alchemical network.
-
- A "transformation" that performs no transformation at all.
- Technically a self-loop, or an edge with the same ``ChemicalSystem`` at
- either end.
-
- Functionally used for applying a dynamics protocol to a ``ChemicalSystem``
- that performs no alchemical transformation at all. This allows e.g.
- equilibrium MD to be performed on a ``ChemicalSystem`` as desired alongside
- alchemical protocols between it and and other ``ChemicalSystem`` objects.
- """
def __init__(
self,
@@ -234,6 +238,27 @@ def __init__(
protocol: Protocol,
name: Optional[str] = None,
):
+ """A non-alchemical edge of an alchemical network.
+
+ A "transformation" that performs no transformation at all.
+ Technically a self-loop, or an edge with the same ``ChemicalSystem`` at
+ either end.
+
+ Functionally used for applying a dynamics protocol to a ``ChemicalSystem``
+ that performs no alchemical transformation at all. This allows e.g.
+ equilibrium MD to be performed on a ``ChemicalSystem`` as desired alongside
+ alchemical protocols between it and other ``ChemicalSystem`` objects.
+
+ Parameters
+ ----------
+ system : ChemicalSystem
+ The (identical) end states of the "transformation" to be sampled
+ protocol : Protocol
+ The sampling method to use on the ``system``
+ name : str, optional
+ A human-readable name for this transformation.
+
+ """
self._system = system
self._protocol = protocol
@@ -245,21 +270,31 @@ def __repr__(self):
return f"{self.__class__.__name__}({content})"
@property
- def stateA(self):
+ def stateA(self) -> ChemicalSystem:
+ """The :class:`.ChemicalSystem` this "transformation" samples.
+
+ Synonymous with ``system`` attribute.
+
+ """
return self._system
@property
- def stateB(self):
+ def stateB(self) -> ChemicalSystem:
+ """The :class:`.ChemicalSystem` this "transformation" samples.
+
+ Synonymous with ``system`` attribute.
+
+ """
return self._system
@property
def system(self) -> ChemicalSystem:
+ """The :class:`.ChemicalSystem` this "transformation" samples."""
return self._system
@property
def protocol(self):
- """
- The protocol for sampling dynamics of the `ChemicalSystem`.
+ """The :class:`.Protocol` for sampling dynamics of the ``system``.
Includes all details needed to perform required
simulations/calculations.
@@ -282,6 +317,9 @@ def create(
"""
Returns a ``ProtocolDAG`` executing this ``NonTransformation.protocol``.
"""
+ # TODO: once we have an implicit component mapping concept, use this
+ # here instead of None to allow use of alchemical protocols with
+ # NonTransformations
return self.protocol.create(
stateA=self.system,
stateB=self.system,
diff --git a/gufe/utils.py b/gufe/utils.py
index f9d3b0ff..ea9ad861 100644
--- a/gufe/utils.py
+++ b/gufe/utils.py
@@ -22,13 +22,13 @@ class ensure_filelike:
the stream will always be closed. Filelike inputs will close
if this parameter is True.
"""
+
def __init__(self, fn, mode=None, force_close=False):
filelikes = (io.TextIOBase, io.RawIOBase, io.BufferedIOBase)
if isinstance(fn, filelikes):
if mode is not None:
warnings.warn(
- f"mode='{mode}' specified with {fn.__class__.__name__}."
- " User-specified mode will be ignored."
+ f"mode='{mode}' specified with {fn.__class__.__name__}." " User-specified mode will be ignored."
)
self.to_open = None
self.do_close = force_close
@@ -51,4 +51,3 @@ def __enter__(self):
def __exit__(self, type, value, traceback):
if self.do_close:
self.context.close()
-
diff --git a/gufe/vendor/pdb_file/PdbxContainers.py b/gufe/vendor/pdb_file/PdbxContainers.py
index adedadc2..e8bcd94d 100644
--- a/gufe/vendor/pdb_file/PdbxContainers.py
+++ b/gufe/vendor/pdb_file/PdbxContainers.py
@@ -5,7 +5,7 @@
#
# Update:
# 23-Mar-2011 jdw Added method to rename attributes in category containers.
-# 05-Apr-2011 jdw Change cif writer to select double quoting as preferred
+# 05-Apr-2011 jdw Change cif writer to select double quoting as preferred
# quoting style where possible.
# 16-Jan-2012 jdw Create base class for DataCategory class
# 22-Mar-2012 jdw when append attributes to existing categories update
@@ -36,29 +36,31 @@
data and definition meta data.
"""
-from __future__ import absolute_import
__docformat__ = "restructuredtext en"
-__author__ = "John Westbrook"
-__email__ = "jwest@rcsb.rutgers.edu"
-__license__ = "Creative Commons Attribution 3.0 Unported"
-__version__ = "V0.01"
+__author__ = "John Westbrook"
+__email__ = "jwest@rcsb.rutgers.edu"
+__license__ = "Creative Commons Attribution 3.0 Unported"
+__version__ = "V0.01"
-import re,sys,traceback
+import re
+import sys
+import traceback
+
+
+class CifName:
+ """Class of utilities for CIF-style data names -"""
-class CifName(object):
- ''' Class of utilities for CIF-style data names -
- '''
def __init__(self):
pass
@staticmethod
def categoryPart(name):
- tname=""
+ tname = ""
if name.startswith("_"):
- tname=name[1:]
+ tname = name[1:]
else:
- tname=name
+ tname = name
i = tname.find(".")
if i == -1:
@@ -72,40 +74,40 @@ def attributePart(name):
if i == -1:
return None
else:
- return name[i+1:]
+ return name[i + 1 :]
-class ContainerBase(object):
- ''' Container base class for data and definition objects.
- '''
- def __init__(self,name):
+class ContainerBase:
+ """Container base class for data and definition objects."""
+
+ def __init__(self, name):
# The enclosing scope of the data container (e.g. data_/save_)
self.__name = name
- # List of category names within this container -
- self.__objNameList=[]
+ # List of category names within this container -
+ self.__objNameList = []
# dictionary of DataCategory objects keyed by category name.
- self.__objCatalog={}
- self.__type=None
+ self.__objCatalog = {}
+ self.__type = None
def getType(self):
return self.__type
- def setType(self,type):
- self.__type=type
-
+ def setType(self, type):
+ self.__type = type
+
def getName(self):
return self.__name
- def setName(self,name):
- self.__name=name
+ def setName(self, name):
+ self.__name = name
- def exists(self,name):
+ def exists(self, name):
if name in self.__objCatalog:
return True
else:
return False
-
- def getObj(self,name):
+
+ def getObj(self, name):
if name in self.__objCatalog:
return self.__objCatalog[name]
else:
@@ -113,55 +115,51 @@ def getObj(self,name):
def getObjNameList(self):
return self.__objNameList
-
- def append(self,obj):
- """ Add the input object to the current object catalog. An existing object
- of the same name will be overwritten.
+
+ def append(self, obj):
+ """Add the input object to the current object catalog. An existing object
+ of the same name will be overwritten.
"""
if obj.getName() is not None:
if obj.getName() not in self.__objCatalog:
- # self.__objNameList is keeping track of object order here --
+ # self.__objNameList is keeping track of object order here --
self.__objNameList.append(obj.getName())
- self.__objCatalog[obj.getName()]=obj
+ self.__objCatalog[obj.getName()] = obj
- def replace(self,obj):
- """ Replace an existing object with the input object
- """
- if ((obj.getName() is not None) and (obj.getName() in self.__objCatalog) ):
- self.__objCatalog[obj.getName()]=obj
-
+ def replace(self, obj):
+ """Replace an existing object with the input object"""
+ if (obj.getName() is not None) and (obj.getName() in self.__objCatalog):
+ self.__objCatalog[obj.getName()] = obj
- def printIt(self,fh=sys.stdout,type="brief"):
- fh.write("+ %s container: %30s contains %4d categories\n" %
- (self.getType(),self.getName(),len(self.__objNameList)))
+ def printIt(self, fh=sys.stdout, type="brief"):
+ fh.write(
+ "+ %s container: %30s contains %4d categories\n" % (self.getType(), self.getName(), len(self.__objNameList))
+ )
for nm in self.__objNameList:
fh.write("--------------------------------------------\n")
fh.write("Data category: %s\n" % nm)
- if type == 'brief':
+ if type == "brief":
self.__objCatalog[nm].printIt(fh)
else:
self.__objCatalog[nm].dumpIt(fh)
-
- def rename(self,curName,newName):
- """ Change the name of an object in place -
- """
+ def rename(self, curName, newName):
+ """Change the name of an object in place -"""
try:
- i=self.__objNameList.index(curName)
- self.__objNameList[i]=newName
- self.__objCatalog[newName]=self.__objCatalog[curName]
+ i = self.__objNameList.index(curName)
+ self.__objNameList[i] = newName
+ self.__objCatalog[newName] = self.__objCatalog[curName]
self.__objCatalog[newName].setName(newName)
return True
except:
return False
- def remove(self,curName):
- """ Revmove object by name. Return True on success or False otherwise.
- """
+ def remove(self, curName):
+ """Revmove object by name. Return True on success or False otherwise."""
try:
if curName in self.__objCatalog:
del self.__objCatalog[curName]
- i=self.__objNameList.index(curName)
+ i = self.__objNameList.index(curName)
del self.__objNameList[i]
return True
else:
@@ -171,171 +169,180 @@ def remove(self,curName):
return False
-
+
class DefinitionContainer(ContainerBase):
- def __init__(self,name):
- super(DefinitionContainer,self).__init__(name)
- self.setType('definition')
+ def __init__(self, name):
+ super().__init__(name)
+ self.setType("definition")
def isCategory(self):
- if self.exists('category'):
+ if self.exists("category"):
return True
return False
def isAttribute(self):
- if self.exists('item'):
+ if self.exists("item"):
return True
return False
-
- def printIt(self,fh=sys.stdout,type="brief"):
- fh.write("Definition container: %30s contains %4d categories\n" %
- (self.getName(),len(self.getObjNameList())))
+ def printIt(self, fh=sys.stdout, type="brief"):
+ fh.write("Definition container: %30s contains %4d categories\n" % (self.getName(), len(self.getObjNameList())))
if self.isCategory():
fh.write("Definition type: category\n")
elif self.isAttribute():
fh.write("Definition type: item\n")
else:
- fh.write("Definition type: undefined\n")
+ fh.write("Definition type: undefined\n")
for nm in self.getObjNameList():
fh.write("--------------------------------------------\n")
fh.write("Definition category: %s\n" % nm)
- if type == 'brief':
+ if type == "brief":
self.getObj(nm).printIt(fh)
else:
self.getObj(nm).dumpId(fh)
-
class DataContainer(ContainerBase):
- ''' Container class for DataCategory objects.
- '''
- def __init__(self,name):
- super(DataContainer,self).__init__(name)
- self.setType('data')
- self.__globalFlag=False
-
- def invokeDataBlockMethod(self,type,method,db):
+ """Container class for DataCategory objects."""
+
+ def __init__(self, name):
+ super().__init__(name)
+ self.setType("data")
+ self.__globalFlag = False
+
+ def invokeDataBlockMethod(self, type, method, db):
self.__currentRow = 1
exec(method.getInline())
def setGlobal(self):
- self.__globalFlag=True
+ self.__globalFlag = True
def getGlobal(self):
- return self.__globalFlag
+ return self.__globalFlag
+
-
+class DataCategoryBase:
+ """Base object definition for a data category -"""
-class DataCategoryBase(object):
- """ Base object definition for a data category -
- """
- def __init__(self,name,attributeNameList=None,rowList=None):
+ def __init__(self, name, attributeNameList=None, rowList=None):
self._name = name
#
if rowList is not None:
- self._rowList=rowList
+ self._rowList = rowList
else:
- self._rowList=[]
+ self._rowList = []
if attributeNameList is not None:
- self._attributeNameList=attributeNameList
+ self._attributeNameList = attributeNameList
else:
- self._attributeNameList=[]
+ self._attributeNameList = []
#
# Derived class data -
#
- self._catalog={}
- self._numAttributes=0
+ self._catalog = {}
+ self._numAttributes = 0
#
self.__setup()
def __setup(self):
self._numAttributes = len(self._attributeNameList)
- self._catalog={}
+ self._catalog = {}
for attributeName in self._attributeNameList:
- attributeNameLC = attributeName.lower()
+ attributeNameLC = attributeName.lower()
self._catalog[attributeNameLC] = attributeName
+
#
- def setRowList(self,rowList):
- self._rowList=rowList
+ def setRowList(self, rowList):
+ self._rowList = rowList
- def setAttributeNameList(self,attributeNameList):
- self._attributeNameList=attributeNameList
+ def setAttributeNameList(self, attributeNameList):
+ self._attributeNameList = attributeNameList
self.__setup()
- def setName(self,name):
- self._name=name
+ def setName(self, name):
+ self._name = name
def get(self):
- return (self._name,self._attributeNameList,self._rowList)
+ return (self._name, self._attributeNameList, self._rowList)
+
-
-
class DataCategory(DataCategoryBase):
- """ Methods for creating, accessing, and formatting PDBx cif data categories.
- """
- def __init__(self,name,attributeNameList=None,rowList=None):
- super(DataCategory,self).__init__(name,attributeNameList,rowList)
+ """Methods for creating, accessing, and formatting PDBx cif data categories."""
+
+ def __init__(self, name, attributeNameList=None, rowList=None):
+ super().__init__(name, attributeNameList, rowList)
#
self.__lfh = sys.stdout
-
- self.__currentRowIndex=0
- self.__currentAttribute=None
+
+ self.__currentRowIndex = 0
+ self.__currentAttribute = None
#
- self.__avoidEmbeddedQuoting=False
+ self.__avoidEmbeddedQuoting = False
#
# --------------------------------------------------------------------
- # any whitespace
- self.__wsRe=re.compile(r"\s")
- self.__wsAndQuotesRe=re.compile(r"[\s'\"]")
+ # any whitespace
+ self.__wsRe = re.compile(r"\s")
+ self.__wsAndQuotesRe = re.compile(r"[\s'\"]")
# any newline or carriage control
- self.__nlRe=re.compile(r"[\n\r]")
+ self.__nlRe = re.compile(r"[\n\r]")
#
- # single quote
- self.__sqRe=re.compile(r"[']")
+ # single quote
+ self.__sqRe = re.compile(r"[']")
#
- self.__sqWsRe=re.compile(r"('\s)|(\s')")
-
- # double quote
- self.__dqRe=re.compile(r'["]')
- self.__dqWsRe=re.compile(r'("\s)|(\s")')
+ self.__sqWsRe = re.compile(r"('\s)|(\s')")
+
+ # double quote
+ self.__dqRe = re.compile(r'["]')
+ self.__dqWsRe = re.compile(r'("\s)|(\s")')
#
- self.__intRe=re.compile(r'^[0-9]+$')
- self.__floatRe=re.compile(r'^-?(([0-9]+)[.]?|([0-9]*[.][0-9]+))([(][0-9]+[)])?([eE][+-]?[0-9]+)?$')
+ self.__intRe = re.compile(r"^[0-9]+$")
+ self.__floatRe = re.compile(r"^-?(([0-9]+)[.]?|([0-9]*[.][0-9]+))([(][0-9]+[)])?([eE][+-]?[0-9]+)?$")
#
- self.__dataTypeList=['DT_NULL_VALUE','DT_INTEGER','DT_FLOAT','DT_UNQUOTED_STRING','DT_ITEM_NAME',
- 'DT_DOUBLE_QUOTED_STRING','DT_SINGLE_QUOTED_STRING','DT_MULTI_LINE_STRING']
- self.__formatTypeList=['FT_NULL_VALUE','FT_NUMBER','FT_NUMBER','FT_UNQUOTED_STRING',
- 'FT_QUOTED_STRING','FT_QUOTED_STRING','FT_QUOTED_STRING','FT_MULTI_LINE_STRING']
+ self.__dataTypeList = [
+ "DT_NULL_VALUE",
+ "DT_INTEGER",
+ "DT_FLOAT",
+ "DT_UNQUOTED_STRING",
+ "DT_ITEM_NAME",
+ "DT_DOUBLE_QUOTED_STRING",
+ "DT_SINGLE_QUOTED_STRING",
+ "DT_MULTI_LINE_STRING",
+ ]
+ self.__formatTypeList = [
+ "FT_NULL_VALUE",
+ "FT_NUMBER",
+ "FT_NUMBER",
+ "FT_UNQUOTED_STRING",
+ "FT_QUOTED_STRING",
+ "FT_QUOTED_STRING",
+ "FT_QUOTED_STRING",
+ "FT_MULTI_LINE_STRING",
+ ]
#
-
def __getitem__(self, x):
- """ Implements list-type functionality -
- Implements op[x] for some special cases -
- x=integer - returns the row in category (normal list behavior)
- x=string - returns the value of attribute 'x' in first row.
+ """Implements list-type functionality -
+ Implements op[x] for some special cases -
+ x=integer - returns the row in category (normal list behavior)
+ x=string - returns the value of attribute 'x' in first row.
"""
if isinstance(x, int):
- #return self._rowList.__getitem__(x)
+ # return self._rowList.__getitem__(x)
return self._rowList[x]
elif isinstance(x, str):
try:
- #return self._rowList[0][x]
- ii=self.getAttributeIndex(x)
+ # return self._rowList[0][x]
+ ii = self.getAttributeIndex(x)
return self._rowList[0][ii]
except (IndexError, KeyError):
raise KeyError
raise TypeError(x)
-
-
+
def getCurrentAttribute(self):
return self.__currentAttribute
-
def getRowIndex(self):
return self.__currentRowIndex
@@ -343,20 +350,20 @@ def getRowList(self):
return self._rowList
def getRowCount(self):
- return (len(self._rowList))
+ return len(self._rowList)
- def getRow(self,index):
+ def getRow(self, index):
try:
return self._rowList[index]
except:
return []
- def removeRow(self,index):
+ def removeRow(self, index):
try:
- if ((index >= 0) and (index < len(self._rowList))):
- del self._rowList[index]
+ if (index >= 0) and (index < len(self._rowList)):
+ del self._rowList[index]
if self.__currentRowIndex >= len(self._rowList):
- self.__currentRowIndex = len(self._rowList) -1
+ self.__currentRowIndex = len(self._rowList) - 1
return True
else:
pass
@@ -365,20 +372,19 @@ def removeRow(self,index):
return False
- def getFullRow(self,index):
- """ Return a full row based on the length of the the attribute list.
- """
+ def getFullRow(self, index):
+ """Return a full row based on the length of the the attribute list."""
try:
- if (len(self._rowList[index]) < self._numAttributes):
- for ii in range( self._numAttributes-len(self._rowList[index])):
- self._rowList[index].append('?')
+ if len(self._rowList[index]) < self._numAttributes:
+ for ii in range(self._numAttributes - len(self._rowList[index])):
+ self._rowList[index].append("?")
return self._rowList[index]
except:
- return ['?' for ii in range(self._numAttributes)]
+ return ["?" for ii in range(self._numAttributes)]
def getName(self):
return self._name
-
+
def getAttributeList(self):
return self._attributeNameList
@@ -386,54 +392,53 @@ def getAttributeCount(self):
return len(self._attributeNameList)
def getAttributeListWithOrder(self):
- oL=[]
- for ii,att in enumerate(self._attributeNameList):
- oL.append((att,ii))
+ oL = []
+ for ii, att in enumerate(self._attributeNameList):
+ oL.append((att, ii))
return oL
- def getAttributeIndex(self,attributeName):
+ def getAttributeIndex(self, attributeName):
try:
return self._attributeNameList.index(attributeName)
except:
return -1
- def hasAttribute(self,attributeName):
+ def hasAttribute(self, attributeName):
return attributeName in self._attributeNameList
-
- def getIndex(self,attributeName):
+
+ def getIndex(self, attributeName):
try:
return self._attributeNameList.index(attributeName)
except:
return -1
def getItemNameList(self):
- itemNameList=[]
+ itemNameList = []
for att in self._attributeNameList:
- itemNameList.append("_"+self._name+"."+att)
+ itemNameList.append("_" + self._name + "." + att)
return itemNameList
-
- def append(self,row):
- #self.__lfh.write("PdbxContainer(append) category %s row %r\n" % (self._name,row))
+
+ def append(self, row):
+ # self.__lfh.write("PdbxContainer(append) category %s row %r\n" % (self._name,row))
self._rowList.append(row)
- def appendAttribute(self,attributeName):
- attributeNameLC = attributeName.lower()
- if attributeNameLC in self._catalog:
+ def appendAttribute(self, attributeName):
+ attributeNameLC = attributeName.lower()
+ if attributeNameLC in self._catalog:
i = self._attributeNameList.index(self._catalog[attributeNameLC])
self._attributeNameList[i] = attributeName
self._catalog[attributeNameLC] = attributeName
- #self.__lfh.write("Appending existing attribute %s\n" % attributeName)
+ # self.__lfh.write("Appending existing attribute %s\n" % attributeName)
else:
- #self.__lfh.write("Appending existing attribute %s\n" % attributeName)
+ # self.__lfh.write("Appending existing attribute %s\n" % attributeName)
self._attributeNameList.append(attributeName)
self._catalog[attributeNameLC] = attributeName
#
self._numAttributes = len(self._attributeNameList)
-
- def appendAttributeExtendRows(self,attributeName):
- attributeNameLC = attributeName.lower()
- if attributeNameLC in self._catalog:
+ def appendAttributeExtendRows(self, attributeName):
+ attributeNameLC = attributeName.lower()
+ if attributeNameLC in self._catalog:
i = self._attributeNameList.index(self._catalog[attributeNameLC])
self._attributeNameList[i] = attributeName
self._catalog[attributeNameLC] = attributeName
@@ -442,15 +447,13 @@ def appendAttributeExtendRows(self,attributeName):
self._attributeNameList.append(attributeName)
self._catalog[attributeNameLC] = attributeName
# add a placeholder to any existing rows for the new attribute.
- if (len(self._rowList) > 0):
+ if len(self._rowList) > 0:
for row in self._rowList:
row.append("?")
#
self._numAttributes = len(self._attributeNameList)
-
-
- def getValue(self,attributeName=None,rowIndex=None):
+ def getValue(self, attributeName=None, rowIndex=None):
if attributeName is None:
attribute = self.__currentAttribute
else:
@@ -458,363 +461,377 @@ def getValue(self,attributeName=None,rowIndex=None):
if rowIndex is None:
rowI = self.__currentRowIndex
else:
- rowI =rowIndex
-
- if isinstance(attribute, str) and isinstance(rowI,int):
+ rowI = rowIndex
+
+ if isinstance(attribute, str) and isinstance(rowI, int):
try:
return self._rowList[rowI][self._attributeNameList.index(attribute)]
- except (IndexError):
- raise IndexError
+ except IndexError:
+ raise IndexError
raise IndexError(attribute)
- def setValue(self,value,attributeName=None,rowIndex=None):
+ def setValue(self, value, attributeName=None, rowIndex=None):
if attributeName is None:
- attribute=self.__currentAttribute
+ attribute = self.__currentAttribute
else:
- attribute=attributeName
+ attribute = attributeName
if rowIndex is None:
rowI = self.__currentRowIndex
else:
rowI = rowIndex
- if isinstance(attribute, str) and isinstance(rowI,int):
+ if isinstance(attribute, str) and isinstance(rowI, int):
try:
# if row index is out of range - add the rows -
- for ii in range(rowI+1 - len(self._rowList)):
- self._rowList.append(self.__emptyRow())
+ for ii in range(rowI + 1 - len(self._rowList)):
+ self._rowList.append(self.__emptyRow())
# self._rowList[rowI][attribute]=value
- ll=len(self._rowList[rowI])
- ind=self._attributeNameList.index(attribute)
-
- # extend the list if needed -
- if ( ind >= ll):
- self._rowList[rowI].extend([None for ii in xrange(2*ind -ll)])
- self._rowList[rowI][ind]=value
- except (IndexError):
- self.__lfh.write("DataCategory(setvalue) index error category %s attribute %s index %d value %r\n" %
- (self._name,attribute,rowI,value))
- traceback.print_exc(file=self.__lfh)
- #raise IndexError
- except (ValueError):
- self.__lfh.write("DataCategory(setvalue) value error category %s attribute %s index %d value %r\n" %
- (self._name,attribute,rowI,value))
- traceback.print_exc(file=self.__lfh)
- #raise ValueError
+ ll = len(self._rowList[rowI])
+ ind = self._attributeNameList.index(attribute)
+
+ # extend the list if needed -
+ if ind >= ll:
+ self._rowList[rowI].extend([None for ii in xrange(2 * ind - ll)])
+ self._rowList[rowI][ind] = value
+ except IndexError:
+ self.__lfh.write(
+ "DataCategory(setvalue) index error category %s attribute %s index %d value %r\n"
+ % (self._name, attribute, rowI, value)
+ )
+ traceback.print_exc(file=self.__lfh)
+ # raise IndexError
+ except ValueError:
+ self.__lfh.write(
+ "DataCategory(setvalue) value error category %s attribute %s index %d value %r\n"
+ % (self._name, attribute, rowI, value)
+ )
+ traceback.print_exc(file=self.__lfh)
+ # raise ValueError
def __emptyRow(self):
return [None for ii in range(len(self._attributeNameList))]
-
- def replaceValue(self,oldValue,newValue,attributeName):
- numReplace=0
+
+ def replaceValue(self, oldValue, newValue, attributeName):
+ numReplace = 0
if attributeName not in self._attributeNameList:
return numReplace
- ind=self._attributeNameList.index(attributeName)
+ ind = self._attributeNameList.index(attributeName)
for row in self._rowList:
if row[ind] == oldValue:
- row[ind]=newValue
+ row[ind] = newValue
numReplace += 1
return numReplace
- def replaceSubstring(self,oldValue,newValue,attributeName):
- ok=False
- if attributeName not in self._attributeNameList:
+ def replaceSubstring(self, oldValue, newValue, attributeName):
+ ok = False
+ if attributeName not in self._attributeNameList:
return ok
- ind=self._attributeNameList.index(attributeName)
+ ind = self._attributeNameList.index(attributeName)
for row in self._rowList:
- val=row[ind]
- row[ind]=val.replace(oldValue,newValue)
+ val = row[ind]
+ row[ind] = val.replace(oldValue, newValue)
if val != row[ind]:
- ok=True
+ ok = True
return ok
-
- def invokeAttributeMethod(self,attributeName,type,method,db):
+
+ def invokeAttributeMethod(self, attributeName, type, method, db):
self.__currentRowIndex = 0
- self.__currentAttribute=attributeName
+ self.__currentAttribute = attributeName
self.appendAttribute(attributeName)
- currentRowIndex=self.__currentRowIndex
+ currentRowIndex = self.__currentRowIndex
#
- ind=self._attributeNameList.index(attributeName)
+ ind = self._attributeNameList.index(attributeName)
if len(self._rowList) == 0:
- row=[None for ii in xrange(len(self._attributeNameList)*2)]
- row[ind]=None
+ row = [None for ii in xrange(len(self._attributeNameList) * 2)]
+ row[ind] = None
self._rowList.append(row)
-
+
for row in self._rowList:
ll = len(row)
- if (ind >= ll):
- row.extend([None for ii in xrange(2*ind-ll)])
- row[ind]=None
+ if ind >= ll:
+ row.extend([None for ii in xrange(2 * ind - ll)])
+ row[ind] = None
exec(method.getInline())
- self.__currentRowIndex+=1
- currentRowIndex=self.__currentRowIndex
+ self.__currentRowIndex += 1
+ currentRowIndex = self.__currentRowIndex
- def invokeCategoryMethod(self,type,method,db):
+ def invokeCategoryMethod(self, type, method, db):
self.__currentRowIndex = 0
exec(method.getInline())
def getAttributeLengthMaximumList(self):
- mList=[0 for i in len(self._attributeNameList)]
+ mList = [0 for i in len(self._attributeNameList)]
for row in self._rowList:
- for indx,val in enumerate(row):
- mList[indx] = max(mList[indx],len(val))
+ for indx, val in enumerate(row):
+ mList[indx] = max(mList[indx], len(val))
return mList
-
- def renameAttribute(self,curAttributeName,newAttributeName):
- """ Change the name of an attribute in place -
- """
+
+ def renameAttribute(self, curAttributeName, newAttributeName):
+ """Change the name of an attribute in place -"""
try:
- i=self._attributeNameList.index(curAttributeName)
- self._attributeNameList[i]=newAttributeName
- del self._catalog[curAttributeName.lower()]
- self._catalog[newAttributeName.lower()]=newAttributeName
+ i = self._attributeNameList.index(curAttributeName)
+ self._attributeNameList[i] = newAttributeName
+ del self._catalog[curAttributeName.lower()]
+ self._catalog[newAttributeName.lower()] = newAttributeName
return True
except:
return False
-
- def printIt(self,fh=sys.stdout):
+
+ def printIt(self, fh=sys.stdout):
fh.write("--------------------------------------------\n")
- fh.write(" Category: %s attribute list length: %d\n" %
- (self._name,len(self._attributeNameList)))
+ fh.write(" Category: %s attribute list length: %d\n" % (self._name, len(self._attributeNameList)))
for at in self._attributeNameList:
- fh.write(" Category: %s attribute: %s\n" % (self._name,at))
-
+ fh.write(f" Category: {self._name} attribute: {at}\n")
+
fh.write(" Row value list length: %d\n" % len(self._rowList))
#
for row in self._rowList[:2]:
#
if len(row) == len(self._attributeNameList):
- for ii,v in enumerate(row):
- fh.write(" %30s: %s ...\n" % (self._attributeNameList[ii],str(v)[:30]))
+ for ii, v in enumerate(row):
+ fh.write(" %30s: %s ...\n" % (self._attributeNameList[ii], str(v)[:30]))
else:
- fh.write("+WARNING - %s data length %d attribute name length %s mismatched\n" %
- (self._name,len(row),len(self._attributeNameList)))
+ fh.write(
+ "+WARNING - %s data length %d attribute name length %s mismatched\n"
+ % (self._name, len(row), len(self._attributeNameList))
+ )
- def dumpIt(self,fh=sys.stdout):
+ def dumpIt(self, fh=sys.stdout):
fh.write("--------------------------------------------\n")
- fh.write(" Category: %s attribute list length: %d\n" %
- (self._name,len(self._attributeNameList)))
+ fh.write(" Category: %s attribute list length: %d\n" % (self._name, len(self._attributeNameList)))
for at in self._attributeNameList:
- fh.write(" Category: %s attribute: %s\n" % (self._name,at))
-
+ fh.write(f" Category: {self._name} attribute: {at}\n")
+
fh.write(" Value list length: %d\n" % len(self._rowList))
for row in self._rowList:
- for ii,v in enumerate(row):
- fh.write(" %30s: %s\n" % (self._attributeNameList[ii],v))
-
+ for ii, v in enumerate(row):
+ fh.write(" %30s: %s\n" % (self._attributeNameList[ii], v))
def __formatPdbx(self, inp):
- """ Format input data following PDBx quoting rules -
- """
+ """Format input data following PDBx quoting rules -"""
try:
- if (inp is None):
- return ("?",'DT_NULL_VALUE')
+ if inp is None:
+ return ("?", "DT_NULL_VALUE")
# pure numerical values are returned as unquoted strings
- if (isinstance(inp,int) or self.__intRe.search(str(inp))):
- return ( [str(inp)],'DT_INTEGER')
+ if isinstance(inp, int) or self.__intRe.search(str(inp)):
+ return ([str(inp)], "DT_INTEGER")
- if (isinstance(inp,float) or self.__floatRe.search(str(inp))):
- return ([str(inp)],'DT_FLOAT')
+ if isinstance(inp, float) or self.__floatRe.search(str(inp)):
+ return ([str(inp)], "DT_FLOAT")
# null value handling -
- if (inp == "." or inp == "?"):
- return ([inp],'DT_NULL_VALUE')
+ if inp == "." or inp == "?":
+ return ([inp], "DT_NULL_VALUE")
- if (inp == ""):
- return (["."],'DT_NULL_VALUE')
+ if inp == "":
+ return (["."], "DT_NULL_VALUE")
# Contains white space or quotes ?
if not self.__wsAndQuotesRe.search(inp):
if inp.startswith("_"):
- return (self.__doubleQuotedList(inp),'DT_ITEM_NAME')
+ return (self.__doubleQuotedList(inp), "DT_ITEM_NAME")
else:
- return ([str(inp)],'DT_UNQUOTED_STRING')
+ return ([str(inp)], "DT_UNQUOTED_STRING")
else:
if self.__nlRe.search(inp):
- return (self.__semiColonQuotedList(inp),'DT_MULTI_LINE_STRING')
+ return (self.__semiColonQuotedList(inp), "DT_MULTI_LINE_STRING")
else:
- if (self.__avoidEmbeddedQuoting):
+ if self.__avoidEmbeddedQuoting:
# change priority to choose double quoting where possible.
if not self.__dqRe.search(inp) and not self.__sqWsRe.search(inp):
- return (self.__doubleQuotedList(inp),'DT_DOUBLE_QUOTED_STRING')
+ return (
+ self.__doubleQuotedList(inp),
+ "DT_DOUBLE_QUOTED_STRING",
+ )
elif not self.__sqRe.search(inp) and not self.__dqWsRe.search(inp):
- return (self.__singleQuotedList(inp),'DT_SINGLE_QUOTED_STRING')
+ return (
+ self.__singleQuotedList(inp),
+ "DT_SINGLE_QUOTED_STRING",
+ )
else:
- return (self.__semiColonQuotedList(inp),'DT_MULTI_LINE_STRING')
+ return (
+ self.__semiColonQuotedList(inp),
+ "DT_MULTI_LINE_STRING",
+ )
else:
# change priority to choose double quoting where possible.
if not self.__dqRe.search(inp):
- return (self.__doubleQuotedList(inp),'DT_DOUBLE_QUOTED_STRING')
+ return (
+ self.__doubleQuotedList(inp),
+ "DT_DOUBLE_QUOTED_STRING",
+ )
elif not self.__sqRe.search(inp):
- return (self.__singleQuotedList(inp),'DT_SINGLE_QUOTED_STRING')
+ return (
+ self.__singleQuotedList(inp),
+ "DT_SINGLE_QUOTED_STRING",
+ )
else:
- return (self.__semiColonQuotedList(inp),'DT_MULTI_LINE_STRING')
-
-
+ return (
+ self.__semiColonQuotedList(inp),
+ "DT_MULTI_LINE_STRING",
+ )
+
except:
- traceback.print_exc(file=self.__lfh)
+ traceback.print_exc(file=self.__lfh)
def __dataTypePdbx(self, inp):
- """ Detect the PDBx data type -
- """
- if (inp is None):
- return ('DT_NULL_VALUE')
-
+ """Detect the PDBx data type -"""
+ if inp is None:
+ return "DT_NULL_VALUE"
+
# pure numerical values are returned as unquoted strings
- if isinstance(inp,int) or self.__intRe.search(str(inp)):
- return ('DT_INTEGER')
+ if isinstance(inp, int) or self.__intRe.search(str(inp)):
+ return "DT_INTEGER"
- if isinstance(inp,float) or self.__floatRe.search(str(inp)):
- return ('DT_FLOAT')
+ if isinstance(inp, float) or self.__floatRe.search(str(inp)):
+ return "DT_FLOAT"
# null value handling -
- if (inp == "." or inp == "?"):
- return ('DT_NULL_VALUE')
+ if inp == "." or inp == "?":
+ return "DT_NULL_VALUE"
- if (inp == ""):
- return ('DT_NULL_VALUE')
+ if inp == "":
+ return "DT_NULL_VALUE"
# Contains white space or quotes ?
if not self.__wsAndQuotesRe.search(inp):
if inp.startswith("_"):
- return ('DT_ITEM_NAME')
+ return "DT_ITEM_NAME"
else:
- return ('DT_UNQUOTED_STRING')
+ return "DT_UNQUOTED_STRING"
else:
if self.__nlRe.search(inp):
- return ('DT_MULTI_LINE_STRING')
+ return "DT_MULTI_LINE_STRING"
else:
- if (self.__avoidEmbeddedQuoting):
+ if self.__avoidEmbeddedQuoting:
if not self.__sqRe.search(inp) and not self.__dqWsRe.search(inp):
- return ('DT_DOUBLE_QUOTED_STRING')
+ return "DT_DOUBLE_QUOTED_STRING"
elif not self.__dqRe.search(inp) and not self.__sqWsRe.search(inp):
- return ('DT_SINGLE_QUOTED_STRING')
+ return "DT_SINGLE_QUOTED_STRING"
else:
- return ('DT_MULTI_LINE_STRING')
+ return "DT_MULTI_LINE_STRING"
else:
if not self.__sqRe.search(inp):
- return ('DT_DOUBLE_QUOTED_STRING')
+ return "DT_DOUBLE_QUOTED_STRING"
elif not self.__dqRe.search(inp):
- return ('DT_SINGLE_QUOTED_STRING')
+ return "DT_SINGLE_QUOTED_STRING"
else:
- return ('DT_MULTI_LINE_STRING')
+ return "DT_MULTI_LINE_STRING"
- def __singleQuotedList(self,inp):
- l=[]
+ def __singleQuotedList(self, inp):
+ l = []
l.append("'")
l.append(inp)
- l.append("'")
- return(l)
+ l.append("'")
+ return l
- def __doubleQuotedList(self,inp):
- l=[]
+ def __doubleQuotedList(self, inp):
+ l = []
l.append('"')
l.append(inp)
- l.append('"')
- return(l)
-
- def __semiColonQuotedList(self,inp):
- l=[]
- l.append("\n")
- if inp[-1] == '\n':
+ l.append('"')
+ return l
+
+ def __semiColonQuotedList(self, inp):
+ l = []
+ l.append("\n")
+ if inp[-1] == "\n":
l.append(";")
l.append(inp)
l.append(";")
- l.append("\n")
+ l.append("\n")
else:
l.append(";")
l.append(inp)
- l.append("\n")
+ l.append("\n")
l.append(";")
- l.append("\n")
+ l.append("\n")
- return(l)
+ return l
- def getValueFormatted(self,attributeName=None,rowIndex=None):
+ def getValueFormatted(self, attributeName=None, rowIndex=None):
if attributeName is None:
- attribute=self.__currentAttribute
+ attribute = self.__currentAttribute
else:
- attribute=attributeName
+ attribute = attributeName
if rowIndex is None:
rowI = self.__currentRowIndex
else:
rowI = rowIndex
-
- if isinstance(attribute, str) and isinstance(rowI,int):
+
+ if isinstance(attribute, str) and isinstance(rowI, int):
try:
- list,type=self.__formatPdbx(self._rowList[rowI][self._attributeNameList.index(attribute)])
+ list, type = self.__formatPdbx(self._rowList[rowI][self._attributeNameList.index(attribute)])
return "".join(list)
- except (IndexError):
- self.__lfh.write("attributeName %s rowI %r rowdata %r\n" % (attributeName,rowI,self._rowList[rowI]))
- raise IndexError
+ except IndexError:
+ self.__lfh.write(f"attributeName {attributeName} rowI {rowI!r} rowdata {self._rowList[rowI]!r}\n")
+ raise IndexError
raise TypeError(attribute)
-
- def getValueFormattedByIndex(self,attributeIndex,rowIndex):
+ def getValueFormattedByIndex(self, attributeIndex, rowIndex):
try:
- list,type=self.__formatPdbx(self._rowList[rowIndex][attributeIndex])
+ list, type = self.__formatPdbx(self._rowList[rowIndex][attributeIndex])
return "".join(list)
- except (IndexError):
- raise IndexError
+ except IndexError:
+ raise IndexError
- def getAttributeValueMaxLengthList(self,steps=1):
- mList=[0 for i in range(len(self._attributeNameList))]
+ def getAttributeValueMaxLengthList(self, steps=1):
+ mList = [0 for i in range(len(self._attributeNameList))]
for row in self._rowList[::steps]:
for indx in range(len(self._attributeNameList)):
- val=row[indx]
- mList[indx] = max(mList[indx],len(str(val)))
+ val = row[indx]
+ mList[indx] = max(mList[indx], len(str(val)))
return mList
- def getFormatTypeList(self,steps=1):
+ def getFormatTypeList(self, steps=1):
try:
- curDataTypeList=['DT_NULL_VALUE' for i in range(len(self._attributeNameList))]
+ curDataTypeList = ["DT_NULL_VALUE" for i in range(len(self._attributeNameList))]
for row in self._rowList[::steps]:
for indx in range(len(self._attributeNameList)):
- val=row[indx]
+ val = row[indx]
# print "index ",indx," val ",val
- dType=self.__dataTypePdbx(val)
- dIndx=self.__dataTypeList.index(dType)
+ dType = self.__dataTypePdbx(val)
+ dIndx = self.__dataTypeList.index(dType)
# print "d type", dType, " d type index ",dIndx
-
- cType=curDataTypeList[indx]
- cIndx=self.__dataTypeList.index(cType)
- cIndx= max(cIndx,dIndx)
- curDataTypeList[indx]=self.__dataTypeList[cIndx]
+
+ cType = curDataTypeList[indx]
+ cIndx = self.__dataTypeList.index(cType)
+ cIndx = max(cIndx, dIndx)
+ curDataTypeList[indx] = self.__dataTypeList[cIndx]
# Map the format types to the data types
- curFormatTypeList=[]
+ curFormatTypeList = []
for dt in curDataTypeList:
- ii=self.__dataTypeList.index(dt)
+ ii = self.__dataTypeList.index(dt)
curFormatTypeList.append(self.__formatTypeList[ii])
except:
- self.__lfh.write("PdbxDataCategory(getFormatTypeList) ++Index error at index %d in row %r\n" % (indx,row))
+ self.__lfh.write("PdbxDataCategory(getFormatTypeList) ++Index error at index %d in row %r\n" % (indx, row))
- return curFormatTypeList,curDataTypeList
+ return curFormatTypeList, curDataTypeList
def getFormatTypeListX(self):
- curDataTypeList=['DT_NULL_VALUE' for i in range(len(self._attributeNameList))]
+ curDataTypeList = ["DT_NULL_VALUE" for i in range(len(self._attributeNameList))]
for row in self._rowList:
for indx in range(len(self._attributeNameList)):
- val=row[indx]
- #print "index ",indx," val ",val
- dType=self.__dataTypePdbx(val)
- dIndx=self.__dataTypeList.index(dType)
- #print "d type", dType, " d type index ",dIndx
-
- cType=curDataTypeList[indx]
- cIndx=self.__dataTypeList.index(cType)
- cIndx= max(cIndx,dIndx)
- curDataTypeList[indx]=self.__dataTypeList[cIndx]
+ val = row[indx]
+ # print "index ",indx," val ",val
+ dType = self.__dataTypePdbx(val)
+ dIndx = self.__dataTypeList.index(dType)
+ # print "d type", dType, " d type index ",dIndx
+
+ cType = curDataTypeList[indx]
+ cIndx = self.__dataTypeList.index(cType)
+ cIndx = max(cIndx, dIndx)
+ curDataTypeList[indx] = self.__dataTypeList[cIndx]
# Map the format types to the data types
- curFormatTypeList=[]
+ curFormatTypeList = []
for dt in curDataTypeList:
- ii=self.__dataTypeList.index(dt)
+ ii = self.__dataTypeList.index(dt)
curFormatTypeList.append(self.__formatTypeList[ii])
- return curFormatTypeList,curDataTypeList
-
-
+ return curFormatTypeList, curDataTypeList
diff --git a/gufe/vendor/pdb_file/PdbxReader.py b/gufe/vendor/pdb_file/PdbxReader.py
index 8b49dfc9..96f161e0 100644
--- a/gufe/vendor/pdb_file/PdbxReader.py
+++ b/gufe/vendor/pdb_file/PdbxReader.py
@@ -16,25 +16,27 @@
The tokenizer used in this module is modeled after the clever parser design
used in the PyMMLIB package.
-
+
PyMMLib Development Group
Authors: Ethan Merritt: merritt@u.washington.ed & Jay Painter: jay.painter@gmail.com
See: http://pymmlib.sourceforge.net/
"""
-from __future__ import absolute_import
import re
+
from .PdbxContainers import *
+
class PdbxError(Exception):
- """ Class for catch general errors
- """
+ """Class for catch general errors"""
+
pass
+
class SyntaxError(Exception):
- """ Class for catching syntax errors
- """
+ """Class for catching syntax errors"""
+
def __init__(self, lineNumber, text):
Exception.__init__(self)
self.lineNumber = lineNumber
@@ -44,27 +46,26 @@ def __str__(self):
return "%%ERROR - [at line: %d] %s" % (self.lineNumber, self.text)
-
-class PdbxReader(object):
- """ PDBx reader for data files and dictionaries.
-
- """
- def __init__(self,ifh):
- """ ifh - input file handle returned by open()
- """
- #
- self.__curLineNumber = 0
- self.__ifh=ifh
- self.__stateDict={"data": "ST_DATA_CONTAINER",
- "loop": "ST_TABLE",
- "global": "ST_GLOBAL_CONTAINER",
- "save": "ST_DEFINITION",
- "stop": "ST_STOP"}
-
+class PdbxReader:
+ """PDBx reader for data files and dictionaries."""
+
+ def __init__(self, ifh):
+ """ifh - input file handle returned by open()"""
+ #
+ self.__curLineNumber = 0
+ self.__ifh = ifh
+ self.__stateDict = {
+ "data": "ST_DATA_CONTAINER",
+ "loop": "ST_TABLE",
+ "global": "ST_GLOBAL_CONTAINER",
+ "save": "ST_DEFINITION",
+ "stop": "ST_STOP",
+ }
+
def read(self, containerList):
"""
Appends to the input list of definition and data containers.
-
+
"""
self.__curLineNumber = 0
try:
@@ -72,7 +73,7 @@ def read(self, containerList):
except StopIteration:
pass
except RuntimeError as err:
- if 'StopIteration' not in str(err):
+ if "StopIteration" not in str(err):
raise
else:
raise PdbxError()
@@ -80,52 +81,51 @@ def read(self, containerList):
def __syntaxError(self, errText):
raise SyntaxError(self.__curLineNumber, errText)
- def __getContainerName(self,inWord):
- """ Returns the name of the data_ or save_ container
- """
+ def __getContainerName(self, inWord):
+ """Returns the name of the data_ or save_ container"""
return str(inWord[5:]).strip()
-
+
def __getState(self, inWord):
- """Identifies reserved syntax elements and assigns an associated state.
+ """Identifies reserved syntax elements and assigns an associated state.
- Returns: (reserved word, state)
- where -
- reserved word - is one of CIF syntax elements:
- data_, loop_, global_, save_, stop_
- state - the parser state required to process this next section.
+ Returns: (reserved word, state)
+ where -
+ reserved word - is one of CIF syntax elements:
+ data_, loop_, global_, save_, stop_
+ state - the parser state required to process this next section.
"""
i = inWord.find("_")
if i == -1:
- return None,"ST_UNKNOWN"
+ return None, "ST_UNKNOWN"
try:
- rWord=inWord[:i].lower()
+ rWord = inWord[:i].lower()
return rWord, self.__stateDict[rWord]
except:
- return None,"ST_UNKNOWN"
-
+ return None, "ST_UNKNOWN"
+
def __parser(self, tokenizer, containerList):
- """ Parser for PDBx data files and dictionaries.
+ """Parser for PDBx data files and dictionaries.
- Input - tokenizer() reentrant method recognizing data item names (_category.attribute)
- quoted strings (single, double and multi-line semi-colon delimited), and unquoted
- strings.
+ Input - tokenizer() reentrant method recognizing data item names (_category.attribute)
+ quoted strings (single, double and multi-line semi-colon delimited), and unquoted
+ strings.
- containerList - list-type container for data and definition objects parsed from
- from the input file.
+ containerList - list-type container for data and definition objects parsed from
+ from the input file.
- Return:
- containerList - is appended with data and definition objects -
+ Return:
+ containerList - is appended with data and definition objects -
"""
# Working container - data or definition
curContainer = None
#
- # Working category container
+ # Working category container
categoryIndex = {}
curCategory = None
#
curRow = None
- state = None
+ state = None
# Find the first reserved word and begin capturing data.
#
@@ -133,16 +133,16 @@ def __parser(self, tokenizer, containerList):
curCatName, curAttName, curQuotedString, curWord = next(tokenizer)
if curWord is None:
continue
- reservedWord, state = self.__getState(curWord)
+ reservedWord, state = self.__getState(curWord)
if reservedWord is not None:
break
-
+
while True:
#
# Set the current state -
#
# At this point in the processing cycle we are expecting a token containing
- # either a '_category.attribute' or a reserved word.
+ # either a '_category.attribute' or a reserved word.
#
if curCatName is not None:
state = "ST_KEY_VALUE_PAIR"
@@ -150,16 +150,16 @@ def __parser(self, tokenizer, containerList):
reservedWord, state = self.__getState(curWord)
else:
self.__syntaxError("Miscellaneous syntax error")
- return
+ return
#
- # Process _category.attribute value assignments
+ # Process _category.attribute value assignments
#
if state == "ST_KEY_VALUE_PAIR":
try:
curCategory = categoryIndex[curCatName]
except KeyError:
- # A new category is encountered - create a container and add a row
+ # A new category is encountered - create a container and add a row
curCategory = categoryIndex[curCatName] = DataCategory(curCatName)
try:
@@ -168,12 +168,12 @@ def __parser(self, tokenizer, containerList):
self.__syntaxError("Category cannot be added to data_ block")
return
- curRow = []
+ curRow = []
curCategory.append(curRow)
else:
# Recover the existing row from the category
try:
- curRow = curCategory[0]
+ curRow = curCategory[0]
except IndexError:
self.__syntaxError("Internal index error accessing category data")
return
@@ -185,18 +185,17 @@ def __parser(self, tokenizer, containerList):
else:
curCategory.appendAttribute(curAttName)
-
# Get the data for this attribute from the next token
tCat, tAtt, curQuotedString, curWord = next(tokenizer)
if tCat is not None or (curQuotedString is None and curWord is None):
- self.__syntaxError("Missing data for item _%s.%s" % (curCatName,curAttName))
+ self.__syntaxError(f"Missing data for item _{curCatName}.{curAttName}")
if curWord is not None:
- #
- # Validation check token for misplaced reserved words -
#
- reservedWord, state = self.__getState(curWord)
+ # Validation check token for misplaced reserved words -
+ #
+ reservedWord, state = self.__getState(curWord)
if reservedWord is not None:
self.__syntaxError("Unexpected reserved word: %s" % (reservedWord))
@@ -218,7 +217,7 @@ def __parser(self, tokenizer, containerList):
# The category name in the next curCatName,curAttName pair
# defines the name of the category container.
- curCatName,curAttName,curQuotedString,curWord = next(tokenizer)
+ curCatName, curAttName, curQuotedString, curWord = next(tokenizer)
if curCatName is None or curAttName is None:
self.__syntaxError("Unexpected token in loop_ declaration")
@@ -239,10 +238,10 @@ def __parser(self, tokenizer, containerList):
curCategory.appendAttribute(curAttName)
- # Read the rest of the loop_ declaration
+ # Read the rest of the loop_ declaration
while True:
curCatName, curAttName, curQuotedString, curWord = next(tokenizer)
-
+
if curCatName is None:
break
@@ -252,19 +251,18 @@ def __parser(self, tokenizer, containerList):
curCategory.appendAttribute(curAttName)
-
- # If the next token is a 'word', check it for any reserved words -
+ # If the next token is a 'word', check it for any reserved words -
if curWord is not None:
- reservedWord, state = self.__getState(curWord)
+ reservedWord, state = self.__getState(curWord)
if reservedWord is not None:
if reservedWord == "stop":
return
else:
self.__syntaxError("Unexpected reserved word after loop declaration: %s" % (reservedWord))
-
- # Read the table of data for this loop_ -
+
+ # Read the table of data for this loop_ -
while True:
- curRow = []
+ curRow = []
curCategory.append(curRow)
for tAtt in curCategory.getAttributeList():
@@ -273,9 +271,9 @@ def __parser(self, tokenizer, containerList):
elif curQuotedString is not None:
curRow.append(curQuotedString)
- curCatName,curAttName,curQuotedString,curWord = next(tokenizer)
+ curCatName, curAttName, curQuotedString, curWord = next(tokenizer)
- # loop_ data processing ends if -
+ # loop_ data processing ends if -
# A new _category.attribute is encountered
if curCatName is not None:
@@ -286,31 +284,30 @@ def __parser(self, tokenizer, containerList):
reservedWord, state = self.__getState(curWord)
if reservedWord is not None:
break
-
- continue
+ continue
elif state == "ST_DEFINITION":
# Ignore trailing unnamed saveframe delimiters e.g. 'save_'
- sName=self.__getContainerName(curWord)
- if (len(sName) > 0):
+ sName = self.__getContainerName(curWord)
+ if len(sName) > 0:
curContainer = DefinitionContainer(sName)
containerList.append(curContainer)
categoryIndex = {}
curCategory = None
- curCatName,curAttName,curQuotedString,curWord = next(tokenizer)
+ curCatName, curAttName, curQuotedString, curWord = next(tokenizer)
elif state == "ST_DATA_CONTAINER":
#
- dName=self.__getContainerName(curWord)
+ dName = self.__getContainerName(curWord)
if len(dName) == 0:
- dName="unidentified"
+ dName = "unidentified"
curContainer = DataContainer(dName)
containerList.append(curContainer)
categoryIndex = {}
curCategory = None
- curCatName,curAttName,curQuotedString,curWord = next(tokenizer)
+ curCatName, curAttName, curQuotedString, curWord = next(tokenizer)
elif state == "ST_STOP":
return
@@ -320,22 +317,21 @@ def __parser(self, tokenizer, containerList):
containerList.append(curContainer)
categoryIndex = {}
curCategory = None
- curCatName,curAttName,curQuotedString,curWord = next(tokenizer)
+ curCatName, curAttName, curQuotedString, curWord = next(tokenizer)
elif state == "ST_UNKNOWN":
self.__syntaxError("Unrecogized syntax element: " + str(curWord))
return
-
def __tokenizer(self, ifh):
- """ Tokenizer method for the mmCIF syntax file -
+ """Tokenizer method for the mmCIF syntax file -
- Each return/yield from this method returns information about
- the next token in the form of a tuple with the following structure.
+ Each return/yield from this method returns information about
+ the next token in the form of a tuple with the following structure.
- (category name, attribute name, quoted strings, words w/o quotes or white space)
+ (category name, attribute name, quoted strings, words w/o quotes or white space)
- Differentiated the reqular expression to the better handle embedded quotes.
+ Differentiated the reqular expression to the better handle embedded quotes.
"""
#
@@ -343,17 +339,17 @@ def __tokenizer(self, ifh):
# outside of this regex.
mmcifRe = re.compile(
r"(?:"
-
- "(?:_(.+?)[.](\S+))" "|" # _category.attribute
-
- "(?:['](.*?)(?:[']\s|[']$))" "|" # single quoted strings
- "(?:[\"](.*?)(?:[\"]\s|[\"]$))" "|" # double quoted strings
-
- "(?:\s*#.*$)" "|" # comments (dumped)
-
- "(\S+)" # unquoted words
-
- ")")
+ r"(?:_(.+?)[.](\S+))"
+ "|" # _category.attribute
+ r"(?:['](.*?)(?:[']\s|[']$))"
+ "|" # single quoted strings
+ r'(?:["](.*?)(?:["]\s|["]$))'
+ "|" # double quoted strings
+ r"(?:\s*#.*$)"
+ "|" # comments (dumped)
+ r"(\S+)" # unquoted words
+ ")"
+ )
fileIter = iter(ifh)
@@ -365,7 +361,7 @@ def __tokenizer(self, ifh):
# Dump comments
if line.startswith("#"):
continue
-
+
# Gobble up the entire semi-colon/multi-line delimited string and
# and stuff this into the string slot in the return tuple
#
@@ -385,7 +381,7 @@ def __tokenizer(self, ifh):
#
# Need to process the remainder of the current line -
line = line[1:]
- #continue
+ # continue
# Apply regex to the current line consolidate the single/double
# quoted within the quoted string category
@@ -398,16 +394,16 @@ def __tokenizer(self, ifh):
qs = tgroups[3]
else:
qs = None
- groups = (tgroups[0],tgroups[1],qs,tgroups[4])
+ groups = (tgroups[0], tgroups[1], qs, tgroups[4])
yield groups
def __tokenizerOrg(self, ifh):
- """ Tokenizer method for the mmCIF syntax file -
+ """Tokenizer method for the mmCIF syntax file -
- Each return/yield from this method returns information about
- the next token in the form of a tuple with the following structure.
+ Each return/yield from this method returns information about
+ the next token in the form of a tuple with the following structure.
- (category name, attribute name, quoted strings, words w/o quotes or white space)
+ (category name, attribute name, quoted strings, words w/o quotes or white space)
"""
#
@@ -415,16 +411,15 @@ def __tokenizerOrg(self, ifh):
# outside of this regex.
mmcifRe = re.compile(
r"(?:"
-
- "(?:_(.+?)[.](\S+))" "|" # _category.attribute
-
- "(?:['\"](.*?)(?:['\"]\s|['\"]$))" "|" # quoted strings
-
- "(?:\s*#.*$)" "|" # comments (dumped)
-
- "(\S+)" # unquoted words
-
- ")")
+ r"(?:_(.+?)[.](\S+))"
+ "|" # _category.attribute
+ "(?:['\"](.*?)(?:['\"]\\s|['\"]$))"
+ "|" # quoted strings
+ r"(?:\s*#.*$)"
+ "|" # comments (dumped)
+ r"(\S+)" # unquoted words
+ ")"
+ )
fileIter = iter(ifh)
@@ -436,7 +431,7 @@ def __tokenizerOrg(self, ifh):
# Dump comments
if line.startswith("#"):
continue
-
+
# Gobble up the entire semi-colon/multi-line delimited string and
# and stuff this into the string slot in the return tuple
#
@@ -456,9 +451,9 @@ def __tokenizerOrg(self, ifh):
#
# Need to process the remainder of the current line -
line = line[1:]
- #continue
+ # continue
- ## Apply regex to the current line
+ ## Apply regex to the current line
for it in mmcifRe.finditer(line):
groups = it.groups()
if groups != (None, None, None, None):
diff --git a/gufe/vendor/pdb_file/data/residues.xml b/gufe/vendor/pdb_file/data/residues.xml
index 7b6cfdd1..2a353a33 100644
--- a/gufe/vendor/pdb_file/data/residues.xml
+++ b/gufe/vendor/pdb_file/data/residues.xml
@@ -899,4 +899,4 @@
-
\ No newline at end of file
+
diff --git a/gufe/vendor/pdb_file/element.py b/gufe/vendor/pdb_file/element.py
index 6d00067c..3b675233 100644
--- a/gufe/vendor/pdb_file/element.py
+++ b/gufe/vendor/pdb_file/element.py
@@ -28,21 +28,18 @@
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
-from __future__ import absolute_import
+
__author__ = "Christopher M. Bruns"
__version__ = "1.0"
+import copyreg
import sys
from collections import OrderedDict
-from openmm.unit import daltons, is_quantity
-if sys.version_info >= (3, 0):
- import copyreg
-else:
- import copy_reg as copyreg
+from openmm.unit import daltons, is_quantity
-class Element(object):
+class Element:
"""An Element represents a chemical element.
The openmm.app.element module contains objects for all the standard chemical elements,
@@ -84,7 +81,7 @@ def __init__(self, number, name, symbol, mass):
Element._elements_by_mass = None
if s in Element._elements_by_symbol:
- raise ValueError('Duplicate element symbol %s' % s)
+ raise ValueError("Duplicate element symbol %s" % s)
Element._elements_by_symbol[s] = self
if number in Element._elements_by_atomic_number:
other_element = Element._elements_by_atomic_number[number]
@@ -128,13 +125,12 @@ def getByMass(mass):
if is_quantity(mass):
mass = mass.value_in_unit(daltons)
if mass < 0:
- raise ValueError('Invalid Higgs field')
+ raise ValueError("Invalid Higgs field")
# If this is our first time calling getByMass (or we added an element
# since the last call), re-generate the ordered by-mass dict cache
if Element._elements_by_mass is None:
Element._elements_by_mass = OrderedDict()
- for elem in sorted(Element._elements_by_symbol.values(),
- key=lambda x: x.mass):
+ for elem in sorted(Element._elements_by_symbol.values(), key=lambda x: x.mass):
Element._elements_by_mass[elem.mass.value_in_unit(daltons)] = elem
diff = mass
@@ -170,151 +166,151 @@ def mass(self):
return self._mass
def __str__(self):
- return '' % self.name
+ return "" % self.name
def __repr__(self):
- return '' % self.name
+ return "" % self.name
+
# This is for backward compatibility.
def get_by_symbol(symbol):
- """ Get the element with a particular chemical symbol. """
+ """Get the element with a particular chemical symbol."""
s = symbol.strip().upper()
return Element._elements_by_symbol[s]
+
def _pickle_element(element):
return (get_by_symbol, (element.symbol,))
+
copyreg.pickle(Element, _pickle_element)
# NOTE: getElementByMass assumes all masses are Quantity instances with unit
# "daltons". All elements need to obey this assumption, or that method will
# fail. No checking is done in getElementByMass for performance reasons
-hydrogen = Element( 1, "hydrogen", "H", 1.007947*daltons)
-deuterium = Element( 1, "deuterium", "D", 2.01355321270*daltons)
-helium = Element( 2, "helium", "He", 4.003*daltons)
-lithium = Element( 3, "lithium", "Li", 6.9412*daltons)
-beryllium = Element( 4, "beryllium", "Be", 9.0121823*daltons)
-boron = Element( 5, "boron", "B", 10.8117*daltons)
-carbon = Element( 6, "carbon", "C", 12.01078*daltons)
-nitrogen = Element( 7, "nitrogen", "N", 14.00672*daltons)
-oxygen = Element( 8, "oxygen", "O", 15.99943*daltons)
-fluorine = Element( 9, "fluorine", "F", 18.99840325*daltons)
-neon = Element( 10, "neon", "Ne", 20.17976*daltons)
-sodium = Element( 11, "sodium", "Na", 22.989769282*daltons)
-magnesium = Element( 12, "magnesium", "Mg", 24.30506*daltons)
-aluminum = Element( 13, "aluminum", "Al", 26.98153868*daltons)
-silicon = Element( 14, "silicon", "Si", 28.08553*daltons)
-phosphorus = Element( 15, "phosphorus", "P", 30.9737622*daltons)
-sulfur = Element( 16, "sulfur", "S", 32.0655*daltons)
-chlorine = Element( 17, "chlorine", "Cl", 35.4532*daltons)
-argon = Element( 18, "argon", "Ar", 39.9481*daltons)
-potassium = Element( 19, "potassium", "K", 39.09831*daltons)
-calcium = Element( 20, "calcium", "Ca", 40.0784*daltons)
-scandium = Element( 21, "scandium", "Sc", 44.9559126*daltons)
-titanium = Element( 22, "titanium", "Ti", 47.8671*daltons)
-vanadium = Element( 23, "vanadium", "V", 50.94151*daltons)
-chromium = Element( 24, "chromium", "Cr", 51.99616*daltons)
-manganese = Element( 25, "manganese", "Mn", 54.9380455*daltons)
-iron = Element( 26, "iron", "Fe", 55.8452*daltons)
-cobalt = Element( 27, "cobalt", "Co", 58.9331955*daltons)
-nickel = Element( 28, "nickel", "Ni", 58.69342*daltons)
-copper = Element( 29, "copper", "Cu", 63.5463*daltons)
-zinc = Element( 30, "zinc", "Zn", 65.4094*daltons)
-gallium = Element( 31, "gallium", "Ga", 69.7231*daltons)
-germanium = Element( 32, "germanium", "Ge", 72.641*daltons)
-arsenic = Element( 33, "arsenic", "As", 74.921602*daltons)
-selenium = Element( 34, "selenium", "Se", 78.963*daltons)
-bromine = Element( 35, "bromine", "Br", 79.9041*daltons)
-krypton = Element( 36, "krypton", "Kr", 83.7982*daltons)
-rubidium = Element( 37, "rubidium", "Rb", 85.46783*daltons)
-strontium = Element( 38, "strontium", "Sr", 87.621*daltons)
-yttrium = Element( 39, "yttrium", "Y", 88.905852*daltons)
-zirconium = Element( 40, "zirconium", "Zr", 91.2242*daltons)
-niobium = Element( 41, "niobium", "Nb", 92.906382*daltons)
-molybdenum = Element( 42, "molybdenum", "Mo", 95.942*daltons)
-technetium = Element( 43, "technetium", "Tc", 98*daltons)
-ruthenium = Element( 44, "ruthenium", "Ru", 101.072*daltons)
-rhodium = Element( 45, "rhodium", "Rh", 102.905502*daltons)
-palladium = Element( 46, "palladium", "Pd", 106.421*daltons)
-silver = Element( 47, "silver", "Ag", 107.86822*daltons)
-cadmium = Element( 48, "cadmium", "Cd", 112.4118*daltons)
-indium = Element( 49, "indium", "In", 114.8183*daltons)
-tin = Element( 50, "tin", "Sn", 118.7107*daltons)
-antimony = Element( 51, "antimony", "Sb", 121.7601*daltons)
-tellurium = Element( 52, "tellurium", "Te", 127.603*daltons)
-iodine = Element( 53, "iodine", "I", 126.904473*daltons)
-xenon = Element( 54, "xenon", "Xe", 131.2936*daltons)
-cesium = Element( 55, "cesium", "Cs", 132.90545192*daltons)
-barium = Element( 56, "barium", "Ba", 137.3277*daltons)
-lanthanum = Element( 57, "lanthanum", "La", 138.905477*daltons)
-cerium = Element( 58, "cerium", "Ce", 140.1161*daltons)
-praseodymium = Element( 59, "praseodymium", "Pr", 140.907652*daltons)
-neodymium = Element( 60, "neodymium", "Nd", 144.2423*daltons)
-promethium = Element( 61, "promethium", "Pm", 145*daltons)
-samarium = Element( 62, "samarium", "Sm", 150.362*daltons)
-europium = Element( 63, "europium", "Eu", 151.9641*daltons)
-gadolinium = Element( 64, "gadolinium", "Gd", 157.253*daltons)
-terbium = Element( 65, "terbium", "Tb", 158.925352*daltons)
-dysprosium = Element( 66, "dysprosium", "Dy", 162.5001*daltons)
-holmium = Element( 67, "holmium", "Ho", 164.930322*daltons)
-erbium = Element( 68, "erbium", "Er", 167.2593*daltons)
-thulium = Element( 69, "thulium", "Tm", 168.934212*daltons)
-ytterbium = Element( 70, "ytterbium", "Yb", 173.043*daltons)
-lutetium = Element( 71, "lutetium", "Lu", 174.9671*daltons)
-hafnium = Element( 72, "hafnium", "Hf", 178.492*daltons)
-tantalum = Element( 73, "tantalum", "Ta", 180.947882*daltons)
-tungsten = Element( 74, "tungsten", "W", 183.841*daltons)
-rhenium = Element( 75, "rhenium", "Re", 186.2071*daltons)
-osmium = Element( 76, "osmium", "Os", 190.233*daltons)
-iridium = Element( 77, "iridium", "Ir", 192.2173*daltons)
-platinum = Element( 78, "platinum", "Pt", 195.0849*daltons)
-gold = Element( 79, "gold", "Au", 196.9665694*daltons)
-mercury = Element( 80, "mercury", "Hg", 200.592*daltons)
-thallium = Element( 81, "thallium", "Tl", 204.38332*daltons)
-lead = Element( 82, "lead", "Pb", 207.21*daltons)
-bismuth = Element( 83, "bismuth", "Bi", 208.980401*daltons)
-polonium = Element( 84, "polonium", "Po", 209*daltons)
-astatine = Element( 85, "astatine", "At", 210*daltons)
-radon = Element( 86, "radon", "Rn", 222.018*daltons)
-francium = Element( 87, "francium", "Fr", 223*daltons)
-radium = Element( 88, "radium", "Ra", 226*daltons)
-actinium = Element( 89, "actinium", "Ac", 227*daltons)
-thorium = Element( 90, "thorium", "Th", 232.038062*daltons)
-protactinium = Element( 91, "protactinium", "Pa", 231.035882*daltons)
-uranium = Element( 92, "uranium", "U", 238.028913*daltons)
-neptunium = Element( 93, "neptunium", "Np", 237*daltons)
-plutonium = Element( 94, "plutonium", "Pu", 244*daltons)
-americium = Element( 95, "americium", "Am", 243*daltons)
-curium = Element( 96, "curium", "Cm", 247*daltons)
-berkelium = Element( 97, "berkelium", "Bk", 247*daltons)
-californium = Element( 98, "californium", "Cf", 251*daltons)
-einsteinium = Element( 99, "einsteinium", "Es", 252*daltons)
-fermium = Element(100, "fermium", "Fm", 257*daltons)
-mendelevium = Element(101, "mendelevium", "Md", 258*daltons)
-nobelium = Element(102, "nobelium", "No", 259*daltons)
-lawrencium = Element(103, "lawrencium", "Lr", 262*daltons)
-rutherfordium = Element(104, "rutherfordium", "Rf", 261*daltons)
-dubnium = Element(105, "dubnium", "Db", 262*daltons)
-seaborgium = Element(106, "seaborgium", "Sg", 266*daltons)
-bohrium = Element(107, "bohrium", "Bh", 264*daltons)
-hassium = Element(108, "hassium", "Hs", 269*daltons)
-meitnerium = Element(109, "meitnerium", "Mt", 268*daltons)
-darmstadtium = Element(110, "darmstadtium", "Ds", 281*daltons)
-roentgenium = Element(111, "roentgenium", "Rg", 272*daltons)
-ununbium = Element(112, "ununbium", "Uub", 285*daltons)
-ununtrium = Element(113, "ununtrium", "Uut", 284*daltons)
-ununquadium = Element(114, "ununquadium", "Uuq", 289*daltons)
-ununpentium = Element(115, "ununpentium", "Uup", 288*daltons)
-ununhexium = Element(116, "ununhexium", "Uuh", 292*daltons)
+hydrogen = Element(1, "hydrogen", "H", 1.007947 * daltons)
+deuterium = Element(1, "deuterium", "D", 2.01355321270 * daltons)
+helium = Element(2, "helium", "He", 4.003 * daltons)
+lithium = Element(3, "lithium", "Li", 6.9412 * daltons)
+beryllium = Element(4, "beryllium", "Be", 9.0121823 * daltons)
+boron = Element(5, "boron", "B", 10.8117 * daltons)
+carbon = Element(6, "carbon", "C", 12.01078 * daltons)
+nitrogen = Element(7, "nitrogen", "N", 14.00672 * daltons)
+oxygen = Element(8, "oxygen", "O", 15.99943 * daltons)
+fluorine = Element(9, "fluorine", "F", 18.99840325 * daltons)
+neon = Element(10, "neon", "Ne", 20.17976 * daltons)
+sodium = Element(11, "sodium", "Na", 22.989769282 * daltons)
+magnesium = Element(12, "magnesium", "Mg", 24.30506 * daltons)
+aluminum = Element(13, "aluminum", "Al", 26.98153868 * daltons)
+silicon = Element(14, "silicon", "Si", 28.08553 * daltons)
+phosphorus = Element(15, "phosphorus", "P", 30.9737622 * daltons)
+sulfur = Element(16, "sulfur", "S", 32.0655 * daltons)
+chlorine = Element(17, "chlorine", "Cl", 35.4532 * daltons)
+argon = Element(18, "argon", "Ar", 39.9481 * daltons)
+potassium = Element(19, "potassium", "K", 39.09831 * daltons)
+calcium = Element(20, "calcium", "Ca", 40.0784 * daltons)
+scandium = Element(21, "scandium", "Sc", 44.9559126 * daltons)
+titanium = Element(22, "titanium", "Ti", 47.8671 * daltons)
+vanadium = Element(23, "vanadium", "V", 50.94151 * daltons)
+chromium = Element(24, "chromium", "Cr", 51.99616 * daltons)
+manganese = Element(25, "manganese", "Mn", 54.9380455 * daltons)
+iron = Element(26, "iron", "Fe", 55.8452 * daltons)
+cobalt = Element(27, "cobalt", "Co", 58.9331955 * daltons)
+nickel = Element(28, "nickel", "Ni", 58.69342 * daltons)
+copper = Element(29, "copper", "Cu", 63.5463 * daltons)
+zinc = Element(30, "zinc", "Zn", 65.4094 * daltons)
+gallium = Element(31, "gallium", "Ga", 69.7231 * daltons)
+germanium = Element(32, "germanium", "Ge", 72.641 * daltons)
+arsenic = Element(33, "arsenic", "As", 74.921602 * daltons)
+selenium = Element(34, "selenium", "Se", 78.963 * daltons)
+bromine = Element(35, "bromine", "Br", 79.9041 * daltons)
+krypton = Element(36, "krypton", "Kr", 83.7982 * daltons)
+rubidium = Element(37, "rubidium", "Rb", 85.46783 * daltons)
+strontium = Element(38, "strontium", "Sr", 87.621 * daltons)
+yttrium = Element(39, "yttrium", "Y", 88.905852 * daltons)
+zirconium = Element(40, "zirconium", "Zr", 91.2242 * daltons)
+niobium = Element(41, "niobium", "Nb", 92.906382 * daltons)
+molybdenum = Element(42, "molybdenum", "Mo", 95.942 * daltons)
+technetium = Element(43, "technetium", "Tc", 98 * daltons)
+ruthenium = Element(44, "ruthenium", "Ru", 101.072 * daltons)
+rhodium = Element(45, "rhodium", "Rh", 102.905502 * daltons)
+palladium = Element(46, "palladium", "Pd", 106.421 * daltons)
+silver = Element(47, "silver", "Ag", 107.86822 * daltons)
+cadmium = Element(48, "cadmium", "Cd", 112.4118 * daltons)
+indium = Element(49, "indium", "In", 114.8183 * daltons)
+tin = Element(50, "tin", "Sn", 118.7107 * daltons)
+antimony = Element(51, "antimony", "Sb", 121.7601 * daltons)
+tellurium = Element(52, "tellurium", "Te", 127.603 * daltons)
+iodine = Element(53, "iodine", "I", 126.904473 * daltons)
+xenon = Element(54, "xenon", "Xe", 131.2936 * daltons)
+cesium = Element(55, "cesium", "Cs", 132.90545192 * daltons)
+barium = Element(56, "barium", "Ba", 137.3277 * daltons)
+lanthanum = Element(57, "lanthanum", "La", 138.905477 * daltons)
+cerium = Element(58, "cerium", "Ce", 140.1161 * daltons)
+praseodymium = Element(59, "praseodymium", "Pr", 140.907652 * daltons)
+neodymium = Element(60, "neodymium", "Nd", 144.2423 * daltons)
+promethium = Element(61, "promethium", "Pm", 145 * daltons)
+samarium = Element(62, "samarium", "Sm", 150.362 * daltons)
+europium = Element(63, "europium", "Eu", 151.9641 * daltons)
+gadolinium = Element(64, "gadolinium", "Gd", 157.253 * daltons)
+terbium = Element(65, "terbium", "Tb", 158.925352 * daltons)
+dysprosium = Element(66, "dysprosium", "Dy", 162.5001 * daltons)
+holmium = Element(67, "holmium", "Ho", 164.930322 * daltons)
+erbium = Element(68, "erbium", "Er", 167.2593 * daltons)
+thulium = Element(69, "thulium", "Tm", 168.934212 * daltons)
+ytterbium = Element(70, "ytterbium", "Yb", 173.043 * daltons)
+lutetium = Element(71, "lutetium", "Lu", 174.9671 * daltons)
+hafnium = Element(72, "hafnium", "Hf", 178.492 * daltons)
+tantalum = Element(73, "tantalum", "Ta", 180.947882 * daltons)
+tungsten = Element(74, "tungsten", "W", 183.841 * daltons)
+rhenium = Element(75, "rhenium", "Re", 186.2071 * daltons)
+osmium = Element(76, "osmium", "Os", 190.233 * daltons)
+iridium = Element(77, "iridium", "Ir", 192.2173 * daltons)
+platinum = Element(78, "platinum", "Pt", 195.0849 * daltons)
+gold = Element(79, "gold", "Au", 196.9665694 * daltons)
+mercury = Element(80, "mercury", "Hg", 200.592 * daltons)
+thallium = Element(81, "thallium", "Tl", 204.38332 * daltons)
+lead = Element(82, "lead", "Pb", 207.21 * daltons)
+bismuth = Element(83, "bismuth", "Bi", 208.980401 * daltons)
+polonium = Element(84, "polonium", "Po", 209 * daltons)
+astatine = Element(85, "astatine", "At", 210 * daltons)
+radon = Element(86, "radon", "Rn", 222.018 * daltons)
+francium = Element(87, "francium", "Fr", 223 * daltons)
+radium = Element(88, "radium", "Ra", 226 * daltons)
+actinium = Element(89, "actinium", "Ac", 227 * daltons)
+thorium = Element(90, "thorium", "Th", 232.038062 * daltons)
+protactinium = Element(91, "protactinium", "Pa", 231.035882 * daltons)
+uranium = Element(92, "uranium", "U", 238.028913 * daltons)
+neptunium = Element(93, "neptunium", "Np", 237 * daltons)
+plutonium = Element(94, "plutonium", "Pu", 244 * daltons)
+americium = Element(95, "americium", "Am", 243 * daltons)
+curium = Element(96, "curium", "Cm", 247 * daltons)
+berkelium = Element(97, "berkelium", "Bk", 247 * daltons)
+californium = Element(98, "californium", "Cf", 251 * daltons)
+einsteinium = Element(99, "einsteinium", "Es", 252 * daltons)
+fermium = Element(100, "fermium", "Fm", 257 * daltons)
+mendelevium = Element(101, "mendelevium", "Md", 258 * daltons)
+nobelium = Element(102, "nobelium", "No", 259 * daltons)
+lawrencium = Element(103, "lawrencium", "Lr", 262 * daltons)
+rutherfordium = Element(104, "rutherfordium", "Rf", 261 * daltons)
+dubnium = Element(105, "dubnium", "Db", 262 * daltons)
+seaborgium = Element(106, "seaborgium", "Sg", 266 * daltons)
+bohrium = Element(107, "bohrium", "Bh", 264 * daltons)
+hassium = Element(108, "hassium", "Hs", 269 * daltons)
+meitnerium = Element(109, "meitnerium", "Mt", 268 * daltons)
+darmstadtium = Element(110, "darmstadtium", "Ds", 281 * daltons)
+roentgenium = Element(111, "roentgenium", "Rg", 272 * daltons)
+ununbium = Element(112, "ununbium", "Uub", 285 * daltons)
+ununtrium = Element(113, "ununtrium", "Uut", 284 * daltons)
+ununquadium = Element(114, "ununquadium", "Uuq", 289 * daltons)
+ununpentium = Element(115, "ununpentium", "Uup", 288 * daltons)
+ununhexium = Element(116, "ununhexium", "Uuh", 292 * daltons)
# Aliases to recognize common alternative spellings. Both the '==' and 'is'
# relational operators will work with any chosen name
sulphur = sulfur
aluminium = aluminum
-if sys.version_info >= (3, 0):
- def _iteritems(dict):
- return dict.items()
-else:
- def _iteritems(dict):
- return dict.iteritems()
+
+def _iteritems(dict):
+ return dict.items()
diff --git a/gufe/vendor/pdb_file/pdbfile.py b/gufe/vendor/pdb_file/pdbfile.py
index df769055..9eb26111 100644
--- a/gufe/vendor/pdb_file/pdbfile.py
+++ b/gufe/vendor/pdb_file/pdbfile.py
@@ -28,40 +28,71 @@
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
-from __future__ import print_function, division, absolute_import
+
__author__ = "Peter Eastman" # was true but now riesben vendored this! HarrHarr
__version__ = "1.0"
+import math
import os
import sys
-import math
-import numpy as np
import xml.etree.ElementTree as etree
from copy import copy
from datetime import date
+import numpy as np
+from openmm.unit import Quantity, angstroms, is_quantity, nanometers, norm
+
+from . import element as elem
+from .pdbstructure import PdbStructure
from .topology import Topology
from .unitcell import computeLengthsAndAngles
-from .pdbstructure import PdbStructure
-from . import element as elem
-from openmm.unit import nanometers, angstroms, is_quantity, norm, Quantity
-
-class PDBFile(object):
+class PDBFile:
"""PDBFile parses a Protein Data Bank (PDB) file and constructs a Topology and a set of atom positions from it.
This class also provides methods for creating PDB files. To write a file containing a single model, call
writeFile(). You also can create files that contain multiple models. To do this, first call writeHeader(),
- then writeModel() once for each model in the file, and finally writeFooter() to complete the file."""
+ then writeModel() once for each model in the file, and finally writeFooter() to complete the file.
+ """
_residueNameReplacements = {}
_atomNameReplacements = {}
- _standardResidues = ['ALA', 'ASN', 'CYS', 'GLU', 'HIS', 'LEU', 'MET', 'PRO', 'THR', 'TYR',
- 'ARG', 'ASP', 'GLN', 'GLY', 'ILE', 'LYS', 'PHE', 'SER', 'TRP', 'VAL',
- 'A', 'G', 'C', 'U', 'I', 'DA', 'DG', 'DC', 'DT', 'DI', 'HOH']
-
- def __init__(self, file, extraParticleIdentifier='EP'):
+ _standardResidues = [
+ "ALA",
+ "ASN",
+ "CYS",
+ "GLU",
+ "HIS",
+ "LEU",
+ "MET",
+ "PRO",
+ "THR",
+ "TYR",
+ "ARG",
+ "ASP",
+ "GLN",
+ "GLY",
+ "ILE",
+ "LYS",
+ "PHE",
+ "SER",
+ "TRP",
+ "VAL",
+ "A",
+ "G",
+ "C",
+ "U",
+ "I",
+ "DA",
+ "DG",
+ "DC",
+ "DT",
+ "DI",
+ "HOH",
+ ]
+
+ def __init__(self, file, extraParticleIdentifier="EP"):
"""Load a PDB file.
The atom positions and Topology can be retrieved by calling getPositions() and getTopology().
@@ -73,10 +104,46 @@ def __init__(self, file, extraParticleIdentifier='EP'):
extraParticleIdentifier : string='EP'
if this value appears in the element column for an ATOM record, the Atom's element will be set to None to mark it as an extra particle
"""
-
- metalElements = ['Al','As','Ba','Ca','Cd','Ce','Co','Cs','Cu','Dy','Fe','Gd','Hg','Ho','In','Ir','K','Li','Mg',
- 'Mn','Mo','Na','Ni','Pb','Pd','Pt','Rb','Rh','Sm','Sr','Te','Tl','V','W','Yb','Zn']
-
+
+ metalElements = [
+ "Al",
+ "As",
+ "Ba",
+ "Ca",
+ "Cd",
+ "Ce",
+ "Co",
+ "Cs",
+ "Cu",
+ "Dy",
+ "Fe",
+ "Gd",
+ "Hg",
+ "Ho",
+ "In",
+ "Ir",
+ "K",
+ "Li",
+ "Mg",
+ "Mn",
+ "Mo",
+ "Na",
+ "Ni",
+ "Pb",
+ "Pd",
+ "Pt",
+ "Rb",
+ "Rh",
+ "Sm",
+ "Sr",
+ "Te",
+ "Tl",
+ "V",
+ "W",
+ "Yb",
+ "Zn",
+ ]
+
top = Topology()
## The Topology read from the PDB file
self.topology = top
@@ -91,7 +158,11 @@ def __init__(self, file, extraParticleIdentifier='EP'):
if isinstance(file, str):
inputfile = open(file)
own_handle = True
- pdb = PdbStructure(inputfile, load_all_models=True, extraParticleIdentifier=extraParticleIdentifier)
+ pdb = PdbStructure(
+ inputfile,
+ load_all_models=True,
+ extraParticleIdentifier=extraParticleIdentifier,
+ )
if own_handle:
inputfile.close()
PDBFile._loadNameReplacementTables()
@@ -120,7 +191,7 @@ def __init__(self, file, extraParticleIdentifier='EP'):
atomName = atomReplacements[atomName]
atomName = atomName.strip()
element = atom.element
- if element == 'EP':
+ if element == "EP":
element = None
elif element is None:
# Try to guess the element.
@@ -128,24 +199,24 @@ def __init__(self, file, extraParticleIdentifier='EP'):
upper = atomName.upper()
while len(upper) > 1 and upper[0].isdigit():
upper = upper[1:]
- if upper.startswith('CL'):
+ if upper.startswith("CL"):
element = elem.chlorine
- elif upper.startswith('NA'):
+ elif upper.startswith("NA"):
element = elem.sodium
- elif upper.startswith('MG'):
+ elif upper.startswith("MG"):
element = elem.magnesium
- elif upper.startswith('BE'):
+ elif upper.startswith("BE"):
element = elem.beryllium
- elif upper.startswith('LI'):
+ elif upper.startswith("LI"):
element = elem.lithium
- elif upper.startswith('K'):
+ elif upper.startswith("K"):
element = elem.potassium
- elif upper.startswith('ZN'):
+ elif upper.startswith("ZN"):
element = elem.zinc
- elif len(residue) == 1 and upper.startswith('CA'):
+ elif len(residue) == 1 and upper.startswith("CA"):
element = elem.calcium
- elif upper.startswith('D') and any(a.name == atomName[1:] for a in residue.iter_atoms()):
- pass # A Drude particle
+ elif upper.startswith("D") and any(a.name == atomName[1:] for a in residue.iter_atoms()):
+ pass # A Drude particle
else:
try:
element = elem.get_by_symbol(upper[0])
@@ -165,7 +236,7 @@ def __init__(self, file, extraParticleIdentifier='EP'):
processedAtomNames.add(atom.get_name())
pos = atom.get_position().value_in_unit(nanometers)
coords.append(np.array([pos[0], pos[1], pos[2]]))
- self._positions.append(coords*nanometers)
+ self._positions.append(coords * nanometers)
## The atom positions read from the PDB file. If the file contains multiple frames, these are the positions in the first frame.
self.positions = self._positions[0]
self.topology.setPeriodicBoxVectors(pdb.get_periodic_box_vectors())
@@ -179,16 +250,25 @@ def __init__(self, file, extraParticleIdentifier='EP'):
for connect in pdb.models[-1].connects:
i = connect[0]
for j in connect[1:]:
- if i in atomByNumber and j in atomByNumber:
+ if i in atomByNumber and j in atomByNumber:
if atomByNumber[i].element is not None and atomByNumber[j].element is not None:
- if atomByNumber[i].element.symbol not in metalElements and atomByNumber[j].element.symbol not in metalElements:
- connectBonds.append((atomByNumber[i], atomByNumber[j]))
- elif atomByNumber[i].element.symbol in metalElements and atomByNumber[j].residue.name not in PDBFile._standardResidues:
- connectBonds.append((atomByNumber[i], atomByNumber[j]))
- elif atomByNumber[j].element.symbol in metalElements and atomByNumber[i].residue.name not in PDBFile._standardResidues:
- connectBonds.append((atomByNumber[i], atomByNumber[j]))
+ if (
+ atomByNumber[i].element.symbol not in metalElements
+ and atomByNumber[j].element.symbol not in metalElements
+ ):
+ connectBonds.append((atomByNumber[i], atomByNumber[j]))
+ elif (
+ atomByNumber[i].element.symbol in metalElements
+ and atomByNumber[j].residue.name not in PDBFile._standardResidues
+ ):
+ connectBonds.append((atomByNumber[i], atomByNumber[j]))
+ elif (
+ atomByNumber[j].element.symbol in metalElements
+ and atomByNumber[i].residue.name not in PDBFile._standardResidues
+ ):
+ connectBonds.append((atomByNumber[i], atomByNumber[j]))
else:
- connectBonds.append((atomByNumber[i], atomByNumber[j]))
+ connectBonds.append((atomByNumber[i], atomByNumber[j]))
if len(connectBonds) > 0:
# Only add bonds that don't already exist.
existingBonds = set(top.bonds())
@@ -218,9 +298,12 @@ def getPositions(self, asNumpy=False, frame=0):
"""
if asNumpy:
if self._numpyPositions is None:
- self._numpyPositions = [None]*len(self._positions)
+ self._numpyPositions = [None] * len(self._positions)
if self._numpyPositions[frame] is None:
- self._numpyPositions[frame] = Quantity(np.array(self._positions[frame].value_in_unit(nanometers)), nanometers)
+ self._numpyPositions[frame] = Quantity(
+ np.array(self._positions[frame].value_in_unit(nanometers)),
+ nanometers,
+ )
return self._numpyPositions[frame]
return self._positions[frame]
@@ -228,31 +311,31 @@ def getPositions(self, asNumpy=False, frame=0):
def _loadNameReplacementTables():
"""Load the list of atom and residue name replacements."""
if len(PDBFile._residueNameReplacements) == 0:
- tree = etree.parse(os.path.join(os.path.dirname(__file__), 'data', 'pdbNames.xml'))
+ tree = etree.parse(os.path.join(os.path.dirname(__file__), "data", "pdbNames.xml"))
allResidues = {}
proteinResidues = {}
nucleicAcidResidues = {}
- for residue in tree.getroot().findall('Residue'):
- name = residue.attrib['name']
- if name == 'All':
+ for residue in tree.getroot().findall("Residue"):
+ name = residue.attrib["name"]
+ if name == "All":
PDBFile._parseResidueAtoms(residue, allResidues)
- elif name == 'Protein':
+ elif name == "Protein":
PDBFile._parseResidueAtoms(residue, proteinResidues)
- elif name == 'Nucleic':
+ elif name == "Nucleic":
PDBFile._parseResidueAtoms(residue, nucleicAcidResidues)
for atom in allResidues:
proteinResidues[atom] = allResidues[atom]
nucleicAcidResidues[atom] = allResidues[atom]
- for residue in tree.getroot().findall('Residue'):
- name = residue.attrib['name']
+ for residue in tree.getroot().findall("Residue"):
+ name = residue.attrib["name"]
for id in residue.attrib:
- if id == 'name' or id.startswith('alt'):
+ if id == "name" or id.startswith("alt"):
PDBFile._residueNameReplacements[residue.attrib[id]] = name
- if 'type' not in residue.attrib:
+ if "type" not in residue.attrib:
atoms = copy(allResidues)
- elif residue.attrib['type'] == 'Protein':
+ elif residue.attrib["type"] == "Protein":
atoms = copy(proteinResidues)
- elif residue.attrib['type'] == 'Nucleic':
+ elif residue.attrib["type"] == "Nucleic":
atoms = copy(nucleicAcidResidues)
else:
atoms = copy(allResidues)
@@ -261,13 +344,19 @@ def _loadNameReplacementTables():
@staticmethod
def _parseResidueAtoms(residue, map):
- for atom in residue.findall('Atom'):
- name = atom.attrib['name']
+ for atom in residue.findall("Atom"):
+ name = atom.attrib["name"]
for id in atom.attrib:
map[atom.attrib[id]] = name
@staticmethod
- def writeFile(topology, positions, file=sys.stdout, keepIds=False, extraParticleIdentifier='EP'):
+ def writeFile(
+ topology,
+ positions,
+ file=sys.stdout,
+ keepIds=False,
+ extraParticleIdentifier="EP",
+ ):
"""Write a PDB file containing a single model.
Parameters
@@ -287,7 +376,13 @@ def writeFile(topology, positions, file=sys.stdout, keepIds=False, extraParticle
String to write in the element column of the ATOM records for atoms whose element is None (extra particles)
"""
PDBFile.writeHeader(topology, file)
- PDBFile.writeModel(topology, positions, file, keepIds=keepIds, extraParticleIdentifier=extraParticleIdentifier)
+ PDBFile.writeModel(
+ topology,
+ positions,
+ file,
+ keepIds=keepIds,
+ extraParticleIdentifier=extraParticleIdentifier,
+ )
PDBFile.writeFooter(topology, file)
@staticmethod
@@ -305,12 +400,29 @@ def writeHeader(topology, file=sys.stdout):
vectors = topology.getPeriodicBoxVectors()
if vectors is not None:
a, b, c, alpha, beta, gamma = computeLengthsAndAngles(vectors)
- RAD_TO_DEG = 180/math.pi
- print("CRYST1%9.3f%9.3f%9.3f%7.2f%7.2f%7.2f P 1 1 " % (
- a*10, b*10, c*10, alpha*RAD_TO_DEG, beta*RAD_TO_DEG, gamma*RAD_TO_DEG), file=file)
+ RAD_TO_DEG = 180 / math.pi
+ print(
+ "CRYST1%9.3f%9.3f%9.3f%7.2f%7.2f%7.2f P 1 1 "
+ % (
+ a * 10,
+ b * 10,
+ c * 10,
+ alpha * RAD_TO_DEG,
+ beta * RAD_TO_DEG,
+ gamma * RAD_TO_DEG,
+ ),
+ file=file,
+ )
@staticmethod
- def writeModel(topology, positions, file=sys.stdout, modelIndex=None, keepIds=False, extraParticleIdentifier='EP'):
+ def writeModel(
+ topology,
+ positions,
+ file=sys.stdout,
+ modelIndex=None,
+ keepIds=False,
+ extraParticleIdentifier="EP",
+ ):
"""Write out a model to a PDB file.
Parameters
@@ -335,26 +447,30 @@ def writeModel(topology, positions, file=sys.stdout, modelIndex=None, keepIds=Fa
"""
if len(list(topology.atoms())) != len(positions):
- raise ValueError('The number of positions must match the number of atoms')
+ raise ValueError("The number of positions must match the number of atoms")
if is_quantity(positions):
positions = positions.value_in_unit(angstroms)
if any(math.isnan(norm(pos)) for pos in positions):
- raise ValueError('Particle position is NaN. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan')
+ raise ValueError(
+ "Particle position is NaN. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan"
+ )
if any(math.isinf(norm(pos)) for pos in positions):
- raise ValueError('Particle position is infinite. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan')
+ raise ValueError(
+ "Particle position is infinite. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan"
+ )
nonHeterogens = PDBFile._standardResidues[:]
- nonHeterogens.remove('HOH')
+ nonHeterogens.remove("HOH")
atomIndex = 1
posIndex = 0
if modelIndex is not None:
print("MODEL %4d" % modelIndex, file=file)
- for (chainIndex, chain) in enumerate(topology.chains()):
+ for chainIndex, chain in enumerate(topology.chains()):
if keepIds and len(chain.id) == 1:
chainName = chain.id
else:
- chainName = chr(ord('A')+chainIndex%26)
+ chainName = chr(ord("A") + chainIndex % 26)
residues = list(chain.residues())
- for (resIndex, res) in enumerate(residues):
+ for resIndex, res in enumerate(residues):
if len(res.name) > 3:
resName = res.name[:3]
else:
@@ -362,7 +478,7 @@ def writeModel(topology, positions, file=sys.stdout, modelIndex=None, keepIds=Fa
if keepIds and len(res.id) < 5:
resId = res.id
else:
- resId = _formatIndex(resIndex+1, 4)
+ resId = _formatIndex(resIndex + 1, 4)
if len(res.insertionCode) == 1:
resIC = res.insertionCode
else:
@@ -377,22 +493,35 @@ def writeModel(topology, positions, file=sys.stdout, modelIndex=None, keepIds=Fa
else:
symbol = extraParticleIdentifier
if len(atom.name) < 4 and atom.name[:1].isalpha() and len(symbol) < 2:
- atomName = ' '+atom.name
+ atomName = " " + atom.name
elif len(atom.name) > 4:
atomName = atom.name[:4]
else:
atomName = atom.name
coords = positions[posIndex]
line = "%s%5s %-4s %3s %s%4s%1s %s%s%s 1.00 0.00 %2s " % (
- recordName, _formatIndex(atomIndex, 5), atomName, resName, chainName, resId, resIC, _format_83(coords[0]),
- _format_83(coords[1]), _format_83(coords[2]), symbol)
+ recordName,
+ _formatIndex(atomIndex, 5),
+ atomName,
+ resName,
+ chainName,
+ resId,
+ resIC,
+ _format_83(coords[0]),
+ _format_83(coords[1]),
+ _format_83(coords[2]),
+ symbol,
+ )
if len(line) != 80:
- raise ValueError('Fixed width overflow detected')
+ raise ValueError("Fixed width overflow detected")
print(line, file=file)
posIndex += 1
atomIndex += 1
- if resIndex == len(residues)-1:
- print("TER %5s %3s %s%4s" % (_formatIndex(atomIndex, 5), resName, chainName, resId), file=file)
+ if resIndex == len(residues) - 1:
+ print(
+ "TER %5s %3s %s%4s" % (_formatIndex(atomIndex, 5), resName, chainName, resId),
+ file=file,
+ )
atomIndex += 1
if modelIndex is not None:
print("ENDMDL", file=file)
@@ -412,9 +541,17 @@ def writeFooter(topology, file=sys.stdout):
conectBonds = []
for atom1, atom2 in topology.bonds():
- if atom1.residue.name not in PDBFile._standardResidues or atom2.residue.name not in PDBFile._standardResidues:
+ if (
+ atom1.residue.name not in PDBFile._standardResidues
+ or atom2.residue.name not in PDBFile._standardResidues
+ ):
conectBonds.append((atom1, atom2))
- elif atom1.name == 'SG' and atom2.name == 'SG' and atom1.residue.name == 'CYS' and atom2.residue.name == 'CYS':
+ elif (
+ atom1.name == "SG"
+ and atom2.name == "SG"
+ and atom1.residue.name == "CYS"
+ and atom2.residue.name == "CYS"
+ ):
conectBonds.append((atom1, atom2))
if len(conectBonds) > 0:
@@ -449,7 +586,16 @@ def writeFooter(topology, file=sys.stdout):
for index1 in sorted(atomBonds):
bonded = atomBonds[index1]
while len(bonded) > 4:
- print("CONECT%5s%5s%5s%5s" % (_formatIndex(index1, 5), _formatIndex(bonded[0], 5), _formatIndex(bonded[1], 5), _formatIndex(bonded[2], 5)), file=file)
+ print(
+ "CONECT%5s%5s%5s%5s"
+ % (
+ _formatIndex(index1, 5),
+ _formatIndex(bonded[0], 5),
+ _formatIndex(bonded[1], 5),
+ _formatIndex(bonded[2], 5),
+ ),
+ file=file,
+ )
del bonded[:4]
line = "CONECT%5s" % _formatIndex(index1, 5)
for index2 in bonded:
@@ -464,19 +610,19 @@ def _format_83(f):
gracefully degrade the precision by lopping off some of the decimal
places. If it's much too large, we throw a ValueError"""
if -999.999 < f < 9999.999:
- return '%8.3f' % f
+ return "%8.3f" % f
if -9999999 < f < 99999999:
- return ('%8.3f' % f)[:8]
- raise ValueError('coordinate "%s" could not be represented '
- 'in a width-8 field' % f)
+ return ("%8.3f" % f)[:8]
+ raise ValueError('coordinate "%s" could not be represented ' "in a width-8 field" % f)
+
def _formatIndex(index, places):
"""Create a string representation of an atom or residue index. If the value is larger than can fit
in the available space, switch to hex.
"""
if index < 10**places:
- format = f'%{places}d'
+ format = f"%{places}d"
return format % index
- format = f'%{places}X'
- shiftedIndex = (index - 10**places + 10*16**(places-1)) % (16**places)
- return format % shiftedIndex
\ No newline at end of file
+ format = f"%{places}X"
+ shiftedIndex = (index - 10**places + 10 * 16 ** (places - 1)) % (16**places)
+ return format % shiftedIndex
diff --git a/gufe/vendor/pdb_file/pdbstructure.py b/gufe/vendor/pdb_file/pdbstructure.py
index 76c5f61a..cda68915 100644
--- a/gufe/vendor/pdb_file/pdbstructure.py
+++ b/gufe/vendor/pdb_file/pdbstructure.py
@@ -28,24 +28,24 @@
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
-from __future__ import absolute_import
-from __future__ import print_function
+
__author__ = "Christopher M. Bruns"
__version__ = "1.0"
+import math
+import sys
+import warnings
+from collections import OrderedDict
+
+import numpy as np
import openmm.unit as unit
from . import element
from .unitcell import computePeriodicBoxVectors
-import numpy as np
-import warnings
-import sys
-import math
-from collections import OrderedDict
-class PdbStructure(object):
+class PdbStructure:
"""
PdbStructure object holds a parsed Protein Data Bank format file.
@@ -125,8 +125,7 @@ class PdbStructure(object):
methods.
"""
-
- def __init__(self, input_stream, load_all_models=False, extraParticleIdentifier='EP'):
+ def __init__(self, input_stream, load_all_models=False, extraParticleIdentifier="EP"):
"""Create a PDB model from a PDB file stream.
Parameters
@@ -160,16 +159,16 @@ def _load(self, input_stream):
# Read one line at a time
for pdb_line in input_stream:
if not isinstance(pdb_line, str):
- pdb_line = pdb_line.decode('utf-8')
+ pdb_line = pdb_line.decode("utf-8")
command = pdb_line[:6]
# Look for atoms
if command == "ATOM " or command == "HETATM":
self._add_atom(Atom(pdb_line, self, self.extraParticleIdentifier))
elif command == "CONECT":
atoms = [_parse_atom_index(pdb_line[6:11])]
- for pos in (11,16,21,26):
+ for pos in (11, 16, 21, 26):
try:
- atoms.append(_parse_atom_index(pdb_line[pos:pos+5]))
+ atoms.append(_parse_atom_index(pdb_line[pos : pos + 5]))
except:
pass
self._current_model.connects.append(atoms)
@@ -192,21 +191,30 @@ def _load(self, input_stream):
self._current_model._current_chain._add_ter_record()
self._reset_residue_numbers()
elif command == "CRYST1":
- a_length = float(pdb_line[6:15])*0.1
- b_length = float(pdb_line[15:24])*0.1
- c_length = float(pdb_line[24:33])*0.1
- alpha = float(pdb_line[33:40])*math.pi/180.0
- beta = float(pdb_line[40:47])*math.pi/180.0
- gamma = float(pdb_line[47:54])*math.pi/180.0
+ a_length = float(pdb_line[6:15]) * 0.1
+ b_length = float(pdb_line[15:24]) * 0.1
+ c_length = float(pdb_line[24:33]) * 0.1
+ alpha = float(pdb_line[33:40]) * math.pi / 180.0
+ beta = float(pdb_line[40:47]) * math.pi / 180.0
+ gamma = float(pdb_line[47:54]) * math.pi / 180.0
if 0 not in (a_length, b_length, c_length):
- self._periodic_box_vectors = computePeriodicBoxVectors(a_length, b_length, c_length, alpha, beta, gamma)
+ self._periodic_box_vectors = computePeriodicBoxVectors(
+ a_length, b_length, c_length, alpha, beta, gamma
+ )
elif command == "SEQRES":
chain_id = pdb_line[11]
if len(self.sequences) == 0 or chain_id != self.sequences[-1].chain_id:
self.sequences.append(Sequence(chain_id))
self.sequences[-1].residues.extend(pdb_line[19:].split())
elif command == "MODRES":
- self.modified_residues.append(ModifiedResidue(pdb_line[16], int(pdb_line[18:22]), pdb_line[12:15].strip(), pdb_line[24:27].strip()))
+ self.modified_residues.append(
+ ModifiedResidue(
+ pdb_line[16],
+ int(pdb_line[18:22]),
+ pdb_line[12:15].strip(),
+ pdb_line[24:27].strip(),
+ )
+ )
self._finalize()
def _reset_atom_numbers(self):
@@ -248,30 +256,25 @@ def __getitem__(self, model_number):
return self.models_by_number[model_number]
def __iter__(self):
- for model in self.models:
- yield model
+ yield from self.models
def iter_models(self, use_all_models=False):
if use_all_models:
- for model in self:
- yield model
+ yield from self
elif len(self.models) > 0:
yield self.models[0]
def iter_chains(self, use_all_models=False):
for model in self.iter_models(use_all_models):
- for chain in model.iter_chains():
- yield chain
+ yield from model.iter_chains()
def iter_residues(self, use_all_models=False):
for model in self.iter_models(use_all_models):
- for res in model.iter_residues():
- yield res
+ yield from model.iter_residues()
def iter_atoms(self, use_all_models=False):
for model in self.iter_models(use_all_models):
- for atom in model.iter_atoms():
- yield atom
+ yield from model.iter_atoms()
def iter_positions(self, use_all_models=False, include_alt_loc=False):
"""
@@ -285,15 +288,13 @@ def iter_positions(self, use_all_models=False, include_alt_loc=False):
Get all positions for each atom, or just the first one.
"""
for model in self.iter_models(use_all_models):
- for loc in model.iter_positions(include_alt_loc):
- yield loc
+ yield from model.iter_positions(include_alt_loc)
def __len__(self):
return len(self.models)
def _add_atom(self, atom):
- """
- """
+ """ """
if self._current_model is None:
self._add_model(Model(0))
atom.model_number = self._current_model.number
@@ -310,14 +311,17 @@ def get_periodic_box_vectors(self):
return self._periodic_box_vectors
-class Sequence(object):
+class Sequence:
"""Sequence holds the sequence of a chain, as specified by SEQRES records."""
+
def __init__(self, chain_id):
self.chain_id = chain_id
self.residues = []
-class ModifiedResidue(object):
+
+class ModifiedResidue:
"""ModifiedResidue holds information about a modified residue, as specified by a MODRES record."""
+
def __init__(self, chain_id, number, residue_name, standard_name):
self.chain_id = chain_id
self.number = number
@@ -325,12 +329,13 @@ def __init__(self, chain_id, number, residue_name, standard_name):
self.standard_name = standard_name
-class Model(object):
+class Model:
"""Model holds one model of a PDB structure.
NMR structures usually have multiple models. This represents one
of them.
"""
+
def __init__(self, model_number=1):
self.number = model_number
self.chains = []
@@ -339,8 +344,7 @@ def __init__(self, model_number=1):
self.connects = []
def _add_atom(self, atom):
- """
- """
+ """ """
if len(self.chains) == 0:
self._add_chain(Chain(atom.chain_id))
# Create a new chain if the chain id has changed
@@ -373,23 +377,19 @@ def __iter__(self):
return iter(self.chains)
def iter_chains(self):
- for chain in self:
- yield chain
+ yield from self
def iter_residues(self):
for chain in self:
- for res in chain.iter_residues():
- yield res
+ yield from chain.iter_residues()
def iter_atoms(self):
for chain in self:
- for atom in chain.iter_atoms():
- yield atom
+ yield from chain.iter_atoms()
def iter_positions(self, include_alt_loc=False):
for chain in self:
- for loc in chain.iter_positions(include_alt_loc):
- yield loc
+ yield from chain.iter_positions(include_alt_loc)
def __len__(self):
return len(self.chains)
@@ -404,9 +404,9 @@ def _finalize(self):
for chain in self.chains:
chain._finalize()
-
- class AtomSerialNumber(object):
+ class AtomSerialNumber:
"""pdb.Model inner class for pass-by-reference incrementable serial number"""
+
def __init__(self, val):
self.val = val
@@ -414,8 +414,8 @@ def increment(self):
self.val += 1
-class Chain(object):
- def __init__(self, chain_id=' '):
+class Chain:
+ def __init__(self, chain_id=" "):
self.chain_id = chain_id
self.residues = []
self.has_ter_record = False
@@ -424,26 +424,57 @@ def __init__(self, chain_id=' '):
self.residues_by_number = {}
def _add_atom(self, atom):
- """
- """
+ """ """
# Create a residue if none have been created
if len(self.residues) == 0:
- self._add_residue(Residue(atom.residue_name_with_spaces, atom.residue_number, atom.insertion_code, atom.alternate_location_indicator))
+ self._add_residue(
+ Residue(
+ atom.residue_name_with_spaces,
+ atom.residue_number,
+ atom.insertion_code,
+ atom.alternate_location_indicator,
+ )
+ )
# Create a residue if the residue information has changed
elif self._current_residue.number != atom.residue_number:
- self._add_residue(Residue(atom.residue_name_with_spaces, atom.residue_number, atom.insertion_code, atom.alternate_location_indicator))
+ self._add_residue(
+ Residue(
+ atom.residue_name_with_spaces,
+ atom.residue_number,
+ atom.insertion_code,
+ atom.alternate_location_indicator,
+ )
+ )
elif self._current_residue.insertion_code != atom.insertion_code:
- self._add_residue(Residue(atom.residue_name_with_spaces, atom.residue_number, atom.insertion_code, atom.alternate_location_indicator))
+ self._add_residue(
+ Residue(
+ atom.residue_name_with_spaces,
+ atom.residue_number,
+ atom.insertion_code,
+ atom.alternate_location_indicator,
+ )
+ )
elif self._current_residue.name_with_spaces == atom.residue_name_with_spaces:
# This is a normal case: number, name, and iCode have not changed
pass
- elif atom.alternate_location_indicator != ' ':
+ elif atom.alternate_location_indicator != " ":
# OK - this is a point mutation, Residue._add_atom will know what to do
pass
- else: # Residue name does not match
+ else: # Residue name does not match
# Only residue name does not match
- warnings.warn("WARNING: two consecutive residues with same number (%s, %s)" % (atom, self._current_residue.atoms[-1]))
- self._add_residue(Residue(atom.residue_name_with_spaces, atom.residue_number, atom.insertion_code, atom.alternate_location_indicator))
+ warnings.warn(
+ "WARNING: two consecutive residues with same number ({}, {})".format(
+ atom, self._current_residue.atoms[-1]
+ )
+ )
+ self._add_residue(
+ Residue(
+ atom.residue_name_with_spaces,
+ atom.residue_number,
+ atom.insertion_code,
+ atom.alternate_location_indicator,
+ )
+ )
self._current_residue._add_atom(atom)
def _add_residue(self, residue):
@@ -463,14 +494,24 @@ def write(self, next_serial_number, output_stream=sys.stdout):
residue.write(next_serial_number, output_stream)
if self.has_ter_record:
r = self.residues[-1]
- print("TER %5d %3s %1s%4d%1s" % (next_serial_number.val, r.name_with_spaces, self.chain_id, r.number, r.insertion_code), file=output_stream)
+ print(
+ "TER %5d %3s %1s%4d%1s"
+ % (
+ next_serial_number.val,
+ r.name_with_spaces,
+ self.chain_id,
+ r.number,
+ r.insertion_code,
+ ),
+ file=output_stream,
+ )
next_serial_number.increment()
def _add_ter_record(self):
self.has_ter_record = True
self._finalize()
- def get_residue(self, residue_number, insertion_code=' '):
+ def get_residue(self, residue_number, insertion_code=" "):
return self.residues_by_num_icode[str(residue_number) + insertion_code]
def __contains__(self, residue_number):
@@ -481,22 +522,18 @@ def __getitem__(self, residue_number):
return self.residues_by_number[residue_number]
def __iter__(self):
- for res in self.residues:
- yield res
+ yield from self.residues
def iter_residues(self):
- for res in self:
- yield res
+ yield from self
def iter_atoms(self):
for res in self:
- for atom in res:
- yield atom;
+ yield from res
def iter_positions(self, include_alt_loc=False):
for res in self:
- for loc in res.iter_positions(include_alt_loc):
- yield loc
+ yield from res.iter_positions(include_alt_loc)
def __len__(self):
return len(self.residues)
@@ -508,8 +545,8 @@ def _finalize(self):
residue._finalize()
-class Residue(object):
- def __init__(self, name, number, insertion_code=' ', primary_alternate_location_indicator=' '):
+class Residue:
+ def __init__(self, name, number, insertion_code=" ", primary_alternate_location_indicator=" "):
alt_loc = primary_alternate_location_indicator
self.primary_location_id = alt_loc
self.locations = {}
@@ -524,30 +561,35 @@ def __init__(self, name, number, insertion_code=' ', primary_alternate_location_
self._current_atom = None
def _add_atom(self, atom):
- """
- """
+ """ """
alt_loc = atom.alternate_location_indicator
if alt_loc not in self.locations:
self.locations[alt_loc] = Residue.Location(alt_loc, atom.residue_name_with_spaces)
assert atom.residue_number == self.number
assert atom.insertion_code == self.insertion_code
# Check whether this is an existing atom with another position
- if (atom.name_with_spaces in self.atoms_by_name):
+ if atom.name_with_spaces in self.atoms_by_name:
old_atom = self.atoms_by_name[atom.name_with_spaces]
# Unless this is a duplicated atom (warn about file error)
if atom.alternate_location_indicator in old_atom.locations:
- warnings.warn("WARNING: duplicate atom (%s, %s)" % (atom, old_atom._pdb_string(old_atom.serial_number, atom.alternate_location_indicator)))
+ warnings.warn(
+ "WARNING: duplicate atom (%s, %s)"
+ % (
+ atom,
+ old_atom._pdb_string(old_atom.serial_number, atom.alternate_location_indicator),
+ )
+ )
else:
for alt_loc, position in atom.locations.items():
old_atom.locations[alt_loc] = position
- return # no new atom added
+ return # no new atom added
# actually use new atom
self.atoms_by_name[atom.name] = atom
self.atoms_by_name[atom.name_with_spaces] = atom
self.atoms.append(atom)
self._current_atom = atom
- def write(self, next_serial_number, output_stream=sys.stdout, alt_loc = "*"):
+ def write(self, next_serial_number, output_stream=sys.stdout, alt_loc="*"):
for atom in self.atoms:
atom.write(next_serial_number, output_stream, alt_loc)
@@ -567,19 +609,26 @@ def set_name_with_spaces(self, name, alt_loc=None):
loc = self.locations[alt_loc]
loc.name_with_spaces = name
loc.name = name.strip()
+
def get_name_with_spaces(self, alt_loc=None):
if alt_loc is None:
alt_loc = self.primary_location_id
loc = self.locations[alt_loc]
return loc.name_with_spaces
- name_with_spaces = property(get_name_with_spaces, set_name_with_spaces, doc='four-character residue name including spaces')
+
+ name_with_spaces = property(
+ get_name_with_spaces,
+ set_name_with_spaces,
+ doc="four-character residue name including spaces",
+ )
def get_name(self, alt_loc=None):
if alt_loc is None:
alt_loc = self.primary_location_id
loc = self.locations[alt_loc]
return loc.name
- name = property(get_name, doc='residue name')
+
+ name = property(get_name, doc="residue name")
def get_atom(self, atom_name):
return self.atoms_by_name[atom_name]
@@ -613,8 +662,7 @@ def __iter__(self):
ATOM 192 CB CYS A 42 38.949 -6.825 12.002 1.00 9.67 C
ATOM 193 SG CYS A 42 37.557 -7.514 12.922 1.00 20.12 S
"""
- for atom in self.iter_atoms():
- yield atom
+ yield from self.iter_atoms()
# Three possibilities: primary alt_loc, certain alt_loc, or all alt_locs
def iter_atoms(self, alt_loc=None):
@@ -628,10 +676,10 @@ def iter_atoms(self, alt_loc=None):
locs = list(alt_loc)
# If an atom has any location in alt_loc, emit the atom
for atom in self.atoms:
- use_atom = False # start pessimistic
+ use_atom = False # start pessimistic
for loc2 in atom.locations.keys():
# print "#%s#%s" % (loc2,locs)
- if locs is None: # means all locations
+ if locs is None: # means all locations
use_atom = True
elif loc2 in locs:
use_atom = True
@@ -667,8 +715,7 @@ def iter_positions(self, include_alt_loc=False):
"""
for atom in self:
if include_alt_loc:
- for loc in atom.iter_positions():
- yield loc
+ yield from atom.iter_positions()
else:
yield atom.position
@@ -680,15 +727,16 @@ class Location:
"""
Inner class of residue to allow different residue names for different alternate_locations.
"""
+
def __init__(self, alternate_location_indicator, residue_name_with_spaces):
self.alternate_location_indicator = alternate_location_indicator
self.residue_name_with_spaces = residue_name_with_spaces
-class Atom(object):
- """Atom represents one atom in a PDB structure.
- """
- def __init__(self, pdb_line, pdbstructure=None, extraParticleIdentifier='EP'):
+class Atom:
+ """Atom represents one atom in a PDB structure."""
+
+ def __init__(self, pdb_line, pdbstructure=None, extraParticleIdentifier="EP"):
"""Create a new pdb.Atom from an ATOM or HETATM line.
Example line:
@@ -741,7 +789,7 @@ def __init__(self, pdb_line, pdbstructure=None, extraParticleIdentifier='EP'):
if possible_fourth_character != " ":
# Fourth character should only be there if official 3 are already full
if len(self.residue_name_with_spaces.strip()) != 3:
- raise ValueError('Misaligned residue name: %s' % pdb_line)
+ raise ValueError("Misaligned residue name: %s" % pdb_line)
self.residue_name_with_spaces += possible_fourth_character
self.residue_name = self.residue_name_with_spaces.strip()
@@ -754,7 +802,11 @@ def __init__(self, pdb_line, pdbstructure=None, extraParticleIdentifier='EP'):
except:
# When VMD runs out of hex values it starts filling the residue ID field with ****.
# Look at the most recent atoms to figure out whether this is a new residue or not.
- if pdbstructure._current_model is None or pdbstructure._current_model._current_chain is None or pdbstructure._current_model._current_chain._current_residue is None:
+ if (
+ pdbstructure._current_model is None
+ or pdbstructure._current_model._current_chain is None
+ or pdbstructure._current_model._current_chain._current_residue is None
+ ):
# This is the first residue in the model.
self.residue_number = pdbstructure._next_residue_number
else:
@@ -781,17 +833,25 @@ def __init__(self, pdb_line, pdbstructure=None, extraParticleIdentifier='EP'):
except:
temperature_factor = unit.Quantity(0.0, unit.angstroms**2)
self.locations = {}
- loc = Atom.Location(alternate_location_indicator, unit.Quantity(np.array([x,y,z]), unit.angstroms), occupancy, temperature_factor, self.residue_name_with_spaces)
+ loc = Atom.Location(
+ alternate_location_indicator,
+ unit.Quantity(np.array([x, y, z]), unit.angstroms),
+ occupancy,
+ temperature_factor,
+ self.residue_name_with_spaces,
+ )
self.locations[alternate_location_indicator] = loc
self.default_location_id = alternate_location_indicator
# segment id, element_symbol, and formal_charge are not always present
self.segment_id = pdb_line[72:76].strip()
self.element_symbol = pdb_line[76:78].strip()
- try: self.formal_charge = int(pdb_line[78:80])
- except ValueError: self.formal_charge = None
+ try:
+ self.formal_charge = int(pdb_line[78:80])
+ except ValueError:
+ self.formal_charge = None
# figure out atom element
if self.element_symbol == extraParticleIdentifier:
- self.element = 'EP'
+ self.element = "EP"
else:
try:
# Try to find a sensible element symbol from columns 76-77
@@ -799,8 +859,8 @@ def __init__(self, pdb_line, pdbstructure=None, extraParticleIdentifier='EP'):
except KeyError:
self.element = None
if pdbstructure is not None:
- pdbstructure._next_atom_number = self.serial_number+1
- pdbstructure._next_residue_number = self.residue_number+1
+ pdbstructure._next_atom_number = self.serial_number + 1
+ pdbstructure._next_residue_number = self.residue_number + 1
def iter_locations(self):
"""
@@ -835,8 +895,7 @@ def iter_coordinates(self):
22.607 A
20.046 A
"""
- for coord in self.position:
- yield coord
+ yield from self.position
# Hide existence of multiple alternate locations to avoid scaring casual users
def get_location(self, location_id=None):
@@ -844,38 +903,51 @@ def get_location(self, location_id=None):
if id is None:
id = self.default_location_id
return self.locations[id]
+
def set_location(self, new_location, location_id=None):
id = location_id
if id is None:
id = self.default_location_id
self.locations[id] = new_location
- location = property(get_location, set_location, doc='default Atom.Location object')
+
+ location = property(get_location, set_location, doc="default Atom.Location object")
def get_position(self):
return self.location.position
+
def set_position(self, coords):
self.location.position = coords
- position = property(get_position, set_position, doc='orthogonal coordinates')
+
+ position = property(get_position, set_position, doc="orthogonal coordinates")
def get_alternate_location_indicator(self):
return self.location.alternate_location_indicator
+
alternate_location_indicator = property(get_alternate_location_indicator)
def get_occupancy(self):
return self.location.occupancy
+
occupancy = property(get_occupancy)
def get_temperature_factor(self):
return self.location.temperature_factor
+
temperature_factor = property(get_temperature_factor)
- def get_x(self): return self.position[0]
+ def get_x(self):
+ return self.position[0]
+
x = property(get_x)
- def get_y(self): return self.position[1]
+ def get_y(self):
+ return self.position[1]
+
y = property(get_y)
- def get_z(self): return self.position[2]
+ def get_z(self):
+ return self.position[2]
+
z = property(get_z)
def _pdb_string(self, serial_number=None, alternate_location_indicator=None):
@@ -893,26 +965,32 @@ def _pdb_string(self, serial_number=None, alternate_location_indicator=None):
long_res_name += " "
assert len(long_res_name) == 4
names = "%-6s%5d %4s%1s%4s%1s%4d%1s " % (
- self.record_name, serial_number, \
- self.name_with_spaces, alternate_location_indicator, \
- long_res_name, self.chain_id, \
- self.residue_number, self.insertion_code)
- numbers = "%8.3f%8.3f%8.3f%6.2f%6.2f " % (
- self.x.value_in_unit(unit.angstroms), \
- self.y.value_in_unit(unit.angstroms), \
- self.z.value_in_unit(unit.angstroms), \
- self.occupancy, \
- self.temperature_factor.value_in_unit(unit.angstroms * unit.angstroms))
- end = "%-4s%2s" % (\
- self.segment_id, self.element_symbol)
+ self.record_name,
+ serial_number,
+ self.name_with_spaces,
+ alternate_location_indicator,
+ long_res_name,
+ self.chain_id,
+ self.residue_number,
+ self.insertion_code,
+ )
+ numbers = "{:8.3f}{:8.3f}{:8.3f}{:6.2f}{:6.2f} ".format(
+ self.x.value_in_unit(unit.angstroms),
+ self.y.value_in_unit(unit.angstroms),
+ self.z.value_in_unit(unit.angstroms),
+ self.occupancy,
+ self.temperature_factor.value_in_unit(unit.angstroms * unit.angstroms),
+ )
+ end = "%-4s%2s" % (self.segment_id, self.element_symbol)
formal_charge = " "
- if (self.formal_charge != None): formal_charge = "%+2d" % self.formal_charge
- return names+numbers+end+formal_charge
+ if self.formal_charge != None:
+ formal_charge = "%+2d" % self.formal_charge
+ return names + numbers + end + formal_charge
def __str__(self):
return self._pdb_string(self.serial_number, self.alternate_location_indicator)
- def write(self, next_serial_number, output_stream=sys.stdout, alt_loc = "*"):
+ def write(self, next_serial_number, output_stream=sys.stdout, alt_loc="*"):
"""
alt_loc = "*" means write all alternate locations
alt_loc = None means write just the primary location
@@ -935,18 +1013,26 @@ def set_name_with_spaces(self, name):
assert len(name) == 4
self._name_with_spaces = name
self._name = name.strip()
+
def get_name_with_spaces(self):
return self._name_with_spaces
- name_with_spaces = property(get_name_with_spaces, set_name_with_spaces, doc='four-character residue name including spaces')
+
+ name_with_spaces = property(
+ get_name_with_spaces,
+ set_name_with_spaces,
+ doc="four-character residue name including spaces",
+ )
def get_name(self):
return self._name
- name = property(get_name, doc='residue name')
- class Location(object):
+ name = property(get_name, doc="residue name")
+
+ class Location:
"""
Inner class of Atom for holding alternate locations
"""
+
def __init__(self, alt_loc, position, occupancy, temperature_factor, residue_name):
self.alternate_location_indicator = alt_loc
self.position = position
@@ -968,8 +1054,7 @@ def __iter__(self):
2 A
3 A
"""
- for coord in self.position:
- yield coord
+ yield from self.position
def __str__(self):
return str(self.position)
@@ -982,14 +1067,17 @@ def _parse_atom_index(index):
except:
return int(index, 16) - 0xA0000 + 100000
+
# run module directly for testing
-if __name__=='__main__':
+if __name__ == "__main__":
# Test the examples in the docstrings
- import doctest, sys
+ import doctest
+ import sys
+
doctest.testmod(sys.modules[__name__])
- import os
import gzip
+ import os
import re
import time
@@ -1006,7 +1094,7 @@ def _parse_atom_index(index):
assert a.residue_number == 299
assert a.insertion_code == " "
assert a.alternate_location_indicator == " "
- assert a.x == 6.167 * unit.angstroms
+ assert a.x == 6.167 * unit.angstroms
assert a.y == 22.607 * unit.angstroms
assert a.z == 20.046 * unit.angstroms
assert a.occupancy == 1.00
@@ -1021,8 +1109,9 @@ def _parse_atom_index(index):
# misaligned residue name - bad
try:
a = Atom("ATOM 2209 CB TYRA 299 6.167 22.607 20.046 1.00 8.12 C")
- assert(False)
- except ValueError: pass
+ assert False
+ except ValueError:
+ pass
# four character residue name -- not so bad
a = Atom("ATOM 2209 CB NTYRA 299 6.167 22.607 20.046 1.00 8.12 C")
@@ -1070,18 +1159,19 @@ def parse_one_pdb(pdb_file_name):
subdir = "ae"
full_subdir = os.path.join(pdb_dir, subdir)
for pdb_file in os.listdir(full_subdir):
- if not re.match("pdb.%2s.\.ent\.gz" % subdir , pdb_file):
+ if not re.match(r"pdb.%2s.\.ent\.gz" % subdir, pdb_file):
continue
full_pdb_file = os.path.join(full_subdir, pdb_file)
parse_one_pdb(full_pdb_file)
if parse_entire_pdb:
for subdir in os.listdir(pdb_dir):
- if not len(subdir) == 2: continue
+ if not len(subdir) == 2:
+ continue
full_subdir = os.path.join(pdb_dir, subdir)
if not os.path.isdir(full_subdir):
continue
for pdb_file in os.listdir(full_subdir):
- if not re.match("pdb.%2s.\.ent\.gz" % subdir , pdb_file):
+ if not re.match(r"pdb.%2s.\.ent\.gz" % subdir, pdb_file):
continue
full_pdb_file = os.path.join(full_subdir, pdb_file)
parse_one_pdb(full_pdb_file)
@@ -1098,4 +1188,4 @@ def parse_one_pdb(pdb_file_name):
print("%d residues found" % residue_count)
print("%d chains found" % chain_count)
print("%d models found" % model_count)
- print("%d structures found" % structure_count)
\ No newline at end of file
+ print("%d structures found" % structure_count)
diff --git a/gufe/vendor/pdb_file/pdbxfile.py b/gufe/vendor/pdb_file/pdbxfile.py
index 60d83246..e5d6433f 100644
--- a/gufe/vendor/pdb_file/pdbxfile.py
+++ b/gufe/vendor/pdb_file/pdbxfile.py
@@ -28,27 +28,25 @@
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
-from __future__ import division, absolute_import, print_function
__author__ = "Peter Eastman"
__version__ = "2.0"
-from openmm.unit import nanometers, angstroms, is_quantity, norm, Quantity
-
-
-import sys
import math
-import numpy as np
-
+import sys
from datetime import date
+import numpy as np
+from openmm.unit import Quantity, angstroms, is_quantity, nanometers, norm
+
+from . import element as elem
+from .pdbfile import PDBFile
from .PdbxReader import PdbxReader
-from .unitcell import computePeriodicBoxVectors, computeLengthsAndAngles
from .topology import Topology
-from .pdbfile import PDBFile
-from . import element as elem
+from .unitcell import computeLengthsAndAngles, computePeriodicBoxVectors
+
-class PDBxFile(object):
+class PDBxFile:
"""PDBxFile parses a PDBx/mmCIF file and constructs a Topology and a set of atom positions from it."""
def __init__(self, file):
@@ -84,52 +82,56 @@ def __init__(self, file):
# Build the topology.
- atomData = block.getObj('atom_site')
- atomNameCol = atomData.getAttributeIndex('auth_atom_id')
+ atomData = block.getObj("atom_site")
+ atomNameCol = atomData.getAttributeIndex("auth_atom_id")
if atomNameCol == -1:
- atomNameCol = atomData.getAttributeIndex('label_atom_id')
- atomIdCol = atomData.getAttributeIndex('id')
- resNameCol = atomData.getAttributeIndex('auth_comp_id')
+ atomNameCol = atomData.getAttributeIndex("label_atom_id")
+ atomIdCol = atomData.getAttributeIndex("id")
+ resNameCol = atomData.getAttributeIndex("auth_comp_id")
if resNameCol == -1:
- resNameCol = atomData.getAttributeIndex('label_comp_id')
- resNumCol = atomData.getAttributeIndex('auth_seq_id')
+ resNameCol = atomData.getAttributeIndex("label_comp_id")
+ resNumCol = atomData.getAttributeIndex("auth_seq_id")
if resNumCol == -1:
- resNumCol = atomData.getAttributeIndex('label_seq_id')
- resInsertionCol = atomData.getAttributeIndex('pdbx_PDB_ins_code')
- chainIdCol = atomData.getAttributeIndex('auth_asym_id')
+ resNumCol = atomData.getAttributeIndex("label_seq_id")
+ resInsertionCol = atomData.getAttributeIndex("pdbx_PDB_ins_code")
+ chainIdCol = atomData.getAttributeIndex("auth_asym_id")
if chainIdCol == -1:
- chainIdCol = atomData.getAttributeIndex('label_asym_id')
+ chainIdCol = atomData.getAttributeIndex("label_asym_id")
altChainIdCol = -1
else:
- altChainIdCol = atomData.getAttributeIndex('label_asym_id')
+ altChainIdCol = atomData.getAttributeIndex("label_asym_id")
if altChainIdCol != -1:
# Figure out which column is best to use for chain IDs.
-
- idSet = set(row[chainIdCol] for row in atomData.getRowList())
- altIdSet = set(row[altChainIdCol] for row in atomData.getRowList())
+
+ idSet = {row[chainIdCol] for row in atomData.getRowList()}
+ altIdSet = {row[altChainIdCol] for row in atomData.getRowList()}
if len(altIdSet) > len(idSet):
chainIdCol, altChainIdCol = (altChainIdCol, chainIdCol)
- elementCol = atomData.getAttributeIndex('type_symbol')
- altIdCol = atomData.getAttributeIndex('label_alt_id')
- modelCol = atomData.getAttributeIndex('pdbx_PDB_model_num')
- xCol = atomData.getAttributeIndex('Cartn_x')
- yCol = atomData.getAttributeIndex('Cartn_y')
- zCol = atomData.getAttributeIndex('Cartn_z')
+ elementCol = atomData.getAttributeIndex("type_symbol")
+ altIdCol = atomData.getAttributeIndex("label_alt_id")
+ modelCol = atomData.getAttributeIndex("pdbx_PDB_model_num")
+ xCol = atomData.getAttributeIndex("Cartn_x")
+ yCol = atomData.getAttributeIndex("Cartn_y")
+ zCol = atomData.getAttributeIndex("Cartn_z")
lastChainId = None
lastAltChainId = None
lastResId = None
- lastInsertionCode = ''
+ lastInsertionCode = ""
atomTable = {}
atomsInResidue = set()
models = []
for row in atomData.getRowList():
- atomKey = ((row[resNumCol], row[chainIdCol], row[atomNameCol]))
- model = ('1' if modelCol == -1 else row[modelCol])
+ atomKey = (row[resNumCol], row[chainIdCol], row[atomNameCol])
+ model = "1" if modelCol == -1 else row[modelCol]
if model not in models:
models.append(model)
self._positions.append([])
modelIndex = models.index(model)
- if row[altIdCol] != '.' and atomKey in atomTable and len(self._positions[modelIndex]) > atomTable[atomKey].index:
+ if (
+ row[altIdCol] != "."
+ and atomKey in atomTable
+ and len(self._positions[modelIndex]) > atomTable[atomKey].index
+ ):
# This row is an alternate position for an existing atom, so ignore it.
continue
@@ -137,11 +139,11 @@ def __init__(self, file):
# This row defines a new atom.
if resInsertionCol == -1:
- insertionCode = ''
+ insertionCode = ""
else:
insertionCode = row[resInsertionCol]
- if insertionCode in ('.', '?'):
- insertionCode = ''
+ if insertionCode in (".", "?"):
+ insertionCode = ""
if lastChainId != row[chainIdCol] or (altChainIdCol != -1 and lastAltChainId != row[altChainIdCol]):
# The start of a new chain.
chain = top.addChain(row[chainIdCol])
@@ -149,9 +151,14 @@ def __init__(self, file):
lastResId = None
if altChainIdCol != -1:
lastAltChainId = row[altChainIdCol]
- if lastResId != row[resNumCol] or lastChainId != row[chainIdCol] or lastInsertionCode != insertionCode or (lastResId == '.' and row[atomNameCol] in atomsInResidue):
+ if (
+ lastResId != row[resNumCol]
+ or lastChainId != row[chainIdCol]
+ or lastInsertionCode != insertionCode
+ or (lastResId == "." and row[atomNameCol] in atomsInResidue)
+ ):
# The start of a new residue.
- resId = (None if resNumCol == -1 else row[resNumCol])
+ resId = None if resNumCol == -1 else row[resNumCol]
resIC = insertionCode
res = top.addResidue(row[resNameCol], chain, resId, resIC)
lastResId = row[resNumCol]
@@ -171,12 +178,18 @@ def __init__(self, file):
try:
atom = atomTable[atomKey]
except KeyError:
- raise ValueError('Unknown atom %s in residue %s %s for model %s' % (row[atomNameCol], row[resNameCol], row[resNumCol], model))
+ raise ValueError(
+ "Unknown atom %s in residue %s %s for model %s"
+ % (row[atomNameCol], row[resNameCol], row[resNumCol], model)
+ )
if atom.index != len(self._positions[modelIndex]):
- raise ValueError('Atom %s for model %s does not match the order of atoms for model %s' % (row[atomIdCol], model, models[0]))
- self._positions[modelIndex].append(np.array([float(row[xCol]), float(row[yCol]), float(row[zCol])])*0.1)
+ raise ValueError(
+ "Atom %s for model %s does not match the order of atoms for model %s"
+ % (row[atomIdCol], model, models[0])
+ )
+ self._positions[modelIndex].append(np.array([float(row[xCol]), float(row[yCol]), float(row[zCol])]) * 0.1)
for i in range(len(self._positions)):
- self._positions[i] = self._positions[i]*nanometers
+ self._positions[i] = self._positions[i] * nanometers
## The atom positions read from the PDBx/mmCIF file. If the file contains multiple frames, these are the positions in the first frame.
self.positions = self._positions[0]
self.topology.createStandardBonds()
@@ -184,28 +197,34 @@ def __init__(self, file):
# Record unit cell information, if present.
- cell = block.getObj('cell')
+ cell = block.getObj("cell")
if cell is not None and cell.getRowCount() > 0:
row = cell.getRow(0)
- (a, b, c) = [float(row[cell.getAttributeIndex(attribute)])*0.1 for attribute in ('length_a', 'length_b', 'length_c')]
- (alpha, beta, gamma) = [float(row[cell.getAttributeIndex(attribute)])*math.pi/180.0 for attribute in ('angle_alpha', 'angle_beta', 'angle_gamma')]
+ (a, b, c) = (
+ float(row[cell.getAttributeIndex(attribute)]) * 0.1
+ for attribute in ("length_a", "length_b", "length_c")
+ )
+ (alpha, beta, gamma) = (
+ float(row[cell.getAttributeIndex(attribute)]) * math.pi / 180.0
+ for attribute in ("angle_alpha", "angle_beta", "angle_gamma")
+ )
self.topology.setPeriodicBoxVectors(computePeriodicBoxVectors(a, b, c, alpha, beta, gamma))
# Add bonds based on struct_conn records.
- connectData = block.getObj('struct_conn')
+ connectData = block.getObj("struct_conn")
if connectData is not None:
- res1Col = connectData.getAttributeIndex('ptnr1_label_seq_id')
- res2Col = connectData.getAttributeIndex('ptnr2_label_seq_id')
- atom1Col = connectData.getAttributeIndex('ptnr1_label_atom_id')
- atom2Col = connectData.getAttributeIndex('ptnr2_label_atom_id')
- asym1Col = connectData.getAttributeIndex('ptnr1_label_asym_id')
- asym2Col = connectData.getAttributeIndex('ptnr2_label_asym_id')
- typeCol = connectData.getAttributeIndex('conn_type_id')
+ res1Col = connectData.getAttributeIndex("ptnr1_label_seq_id")
+ res2Col = connectData.getAttributeIndex("ptnr2_label_seq_id")
+ atom1Col = connectData.getAttributeIndex("ptnr1_label_atom_id")
+ atom2Col = connectData.getAttributeIndex("ptnr2_label_atom_id")
+ asym1Col = connectData.getAttributeIndex("ptnr1_label_asym_id")
+ asym2Col = connectData.getAttributeIndex("ptnr2_label_asym_id")
+ typeCol = connectData.getAttributeIndex("conn_type_id")
connectBonds = []
for row in connectData.getRowList():
type = row[typeCol][:6]
- if type in ('covale', 'disulf', 'modres'):
+ if type in ("covale", "disulf", "modres"):
key1 = (row[res1Col], row[asym1Col], row[atom1Col])
key2 = (row[res2Col], row[asym2Col], row[atom2Col])
if key1 in atomTable and key2 in atomTable:
@@ -239,15 +258,17 @@ def getPositions(self, asnp=False, frame=0):
"""
if asnp:
if self._npPositions is None:
- self._npPositions = [None]*len(self._positions)
+ self._npPositions = [None] * len(self._positions)
if self._npPositions[frame] is None:
- self._npPositions[frame] = Quantity(np.array(self._positions[frame].value_in_unit(nanometers)), nanometers)
+ self._npPositions[frame] = Quantity(
+ np.array(self._positions[frame].value_in_unit(nanometers)),
+ nanometers,
+ )
return self._npPositions[frame]
return self._positions[frame]
@staticmethod
- def writeFile(topology, positions, file=sys.stdout, keepIds=False,
- entry=None):
+ def writeFile(topology, positions, file=sys.stdout, keepIds=False, entry=None):
"""Write a PDBx/mmCIF file containing a single model.
Parameters
@@ -288,46 +309,54 @@ def writeHeader(topology, file=sys.stdout, entry=None, keepIds=False):
PDBx/mmCIF format. Otherwise, the output file will be invalid.
"""
if entry is not None:
- print('data_%s' % entry, file=file)
+ print("data_%s" % entry, file=file)
else:
- print('data_cell', file=file)
+ print("data_cell", file=file)
print("# Created with OpenMM %s" % (str(date.today())), file=file)
- print('#', file=file)
+ print("#", file=file)
vectors = topology.getPeriodicBoxVectors()
if vectors is not None:
a, b, c, alpha, beta, gamma = computeLengthsAndAngles(vectors)
- RAD_TO_DEG = 180/math.pi
- print('_cell.length_a %10.4f' % (a*10), file=file)
- print('_cell.length_b %10.4f' % (b*10), file=file)
- print('_cell.length_c %10.4f' % (c*10), file=file)
- print('_cell.angle_alpha %10.4f' % (alpha*RAD_TO_DEG), file=file)
- print('_cell.angle_beta %10.4f' % (beta*RAD_TO_DEG), file=file)
- print('_cell.angle_gamma %10.4f' % (gamma*RAD_TO_DEG), file=file)
- print('#', file=file)
+ RAD_TO_DEG = 180 / math.pi
+ print("_cell.length_a %10.4f" % (a * 10), file=file)
+ print("_cell.length_b %10.4f" % (b * 10), file=file)
+ print("_cell.length_c %10.4f" % (c * 10), file=file)
+ print("_cell.angle_alpha %10.4f" % (alpha * RAD_TO_DEG), file=file)
+ print("_cell.angle_beta %10.4f" % (beta * RAD_TO_DEG), file=file)
+ print("_cell.angle_gamma %10.4f" % (gamma * RAD_TO_DEG), file=file)
+ print("#", file=file)
# Identify bonds that should be listed in the file.
bonds = []
for atom1, atom2 in topology.bonds():
- if atom1.residue.name not in PDBFile._standardResidues or atom2.residue.name not in PDBFile._standardResidues:
+ if (
+ atom1.residue.name not in PDBFile._standardResidues
+ or atom2.residue.name not in PDBFile._standardResidues
+ ):
bonds.append((atom1, atom2))
- elif atom1.name == 'SG' and atom2.name == 'SG' and atom1.residue.name == 'CYS' and atom2.residue.name == 'CYS':
+ elif (
+ atom1.name == "SG"
+ and atom2.name == "SG"
+ and atom1.residue.name == "CYS"
+ and atom2.residue.name == "CYS"
+ ):
bonds.append((atom1, atom2))
if len(bonds) > 0:
# Write the bond information.
- print('loop_', file=file)
- print('_struct_conn.id', file=file)
- print('_struct_conn.conn_type_id', file=file)
- print('_struct_conn.ptnr1_label_asym_id', file=file)
- print('_struct_conn.ptnr1_label_comp_id', file=file)
- print('_struct_conn.ptnr1_label_seq_id', file=file)
- print('_struct_conn.ptnr1_label_atom_id', file=file)
- print('_struct_conn.ptnr2_label_asym_id', file=file)
- print('_struct_conn.ptnr2_label_comp_id', file=file)
- print('_struct_conn.ptnr2_label_seq_id', file=file)
- print('_struct_conn.ptnr2_label_atom_id', file=file)
+ print("loop_", file=file)
+ print("_struct_conn.id", file=file)
+ print("_struct_conn.conn_type_id", file=file)
+ print("_struct_conn.ptnr1_label_asym_id", file=file)
+ print("_struct_conn.ptnr1_label_comp_id", file=file)
+ print("_struct_conn.ptnr1_label_seq_id", file=file)
+ print("_struct_conn.ptnr1_label_atom_id", file=file)
+ print("_struct_conn.ptnr2_label_asym_id", file=file)
+ print("_struct_conn.ptnr2_label_comp_id", file=file)
+ print("_struct_conn.ptnr2_label_seq_id", file=file)
+ print("_struct_conn.ptnr2_label_atom_id", file=file)
chainIds = {}
resIds = {}
if keepIds:
@@ -336,49 +365,63 @@ def writeHeader(topology, file=sys.stdout, entry=None, keepIds=False):
for res in topology.residues():
resIds[res] = res.id
else:
- for (chainIndex, chain) in enumerate(topology.chains()):
- chainIds[chain] = chr(ord('A')+chainIndex%26)
- for (resIndex, res) in enumerate(chain.residues()):
- resIds[res] = resIndex+1
+ for chainIndex, chain in enumerate(topology.chains()):
+ chainIds[chain] = chr(ord("A") + chainIndex % 26)
+ for resIndex, res in enumerate(chain.residues()):
+ resIds[res] = resIndex + 1
for i, (atom1, atom2) in enumerate(bonds):
- if atom1.residue.name == 'CYS' and atom2.residue.name == 'CYS':
- bondType = 'disulf'
+ if atom1.residue.name == "CYS" and atom2.residue.name == "CYS":
+ bondType = "disulf"
else:
- bondType = 'covale'
+ bondType = "covale"
line = "bond%d %s %s %-4s %5s %-4s %s %-4s %5s %-4s"
- print(line % (i+1, bondType, chainIds[atom1.residue.chain], atom1.residue.name, resIds[atom1.residue], atom1.name,
- chainIds[atom2.residue.chain], atom2.residue.name, resIds[atom2.residue], atom2.name), file=file)
- print('#', file=file)
+ print(
+ line
+ % (
+ i + 1,
+ bondType,
+ chainIds[atom1.residue.chain],
+ atom1.residue.name,
+ resIds[atom1.residue],
+ atom1.name,
+ chainIds[atom2.residue.chain],
+ atom2.residue.name,
+ resIds[atom2.residue],
+ atom2.name,
+ ),
+ file=file,
+ )
+ print("#", file=file)
# Write the header for the atom coordinates.
- print('loop_', file=file)
- print('_atom_site.group_PDB', file=file)
- print('_atom_site.id', file=file)
- print('_atom_site.type_symbol', file=file)
- print('_atom_site.label_atom_id', file=file)
- print('_atom_site.label_alt_id', file=file)
- print('_atom_site.label_comp_id', file=file)
- print('_atom_site.label_asym_id', file=file)
- print('_atom_site.label_entity_id', file=file)
- print('_atom_site.label_seq_id', file=file)
- print('_atom_site.pdbx_PDB_ins_code', file=file)
- print('_atom_site.Cartn_x', file=file)
- print('_atom_site.Cartn_y', file=file)
- print('_atom_site.Cartn_z', file=file)
- print('_atom_site.occupancy', file=file)
- print('_atom_site.B_iso_or_equiv', file=file)
- print('_atom_site.Cartn_x_esd', file=file)
- print('_atom_site.Cartn_y_esd', file=file)
- print('_atom_site.Cartn_z_esd', file=file)
- print('_atom_site.occupancy_esd', file=file)
- print('_atom_site.B_iso_or_equiv_esd', file=file)
- print('_atom_site.pdbx_formal_charge', file=file)
- print('_atom_site.auth_seq_id', file=file)
- print('_atom_site.auth_comp_id', file=file)
- print('_atom_site.auth_asym_id', file=file)
- print('_atom_site.auth_atom_id', file=file)
- print('_atom_site.pdbx_PDB_model_num', file=file)
+ print("loop_", file=file)
+ print("_atom_site.group_PDB", file=file)
+ print("_atom_site.id", file=file)
+ print("_atom_site.type_symbol", file=file)
+ print("_atom_site.label_atom_id", file=file)
+ print("_atom_site.label_alt_id", file=file)
+ print("_atom_site.label_comp_id", file=file)
+ print("_atom_site.label_asym_id", file=file)
+ print("_atom_site.label_entity_id", file=file)
+ print("_atom_site.label_seq_id", file=file)
+ print("_atom_site.pdbx_PDB_ins_code", file=file)
+ print("_atom_site.Cartn_x", file=file)
+ print("_atom_site.Cartn_y", file=file)
+ print("_atom_site.Cartn_z", file=file)
+ print("_atom_site.occupancy", file=file)
+ print("_atom_site.B_iso_or_equiv", file=file)
+ print("_atom_site.Cartn_x_esd", file=file)
+ print("_atom_site.Cartn_y_esd", file=file)
+ print("_atom_site.Cartn_z_esd", file=file)
+ print("_atom_site.occupancy_esd", file=file)
+ print("_atom_site.B_iso_or_equiv_esd", file=file)
+ print("_atom_site.pdbx_formal_charge", file=file)
+ print("_atom_site.auth_seq_id", file=file)
+ print("_atom_site.auth_comp_id", file=file)
+ print("_atom_site.auth_asym_id", file=file)
+ print("_atom_site.auth_atom_id", file=file)
+ print("_atom_site.pdbx_PDB_model_num", file=file)
@staticmethod
def writeModel(topology, positions, file=sys.stdout, modelIndex=1, keepIds=False):
@@ -401,30 +444,30 @@ def writeModel(topology, positions, file=sys.stdout, modelIndex=1, keepIds=False
PDBx/mmCIF format. Otherwise, the output file will be invalid.
"""
if len(list(topology.atoms())) != len(positions):
- raise ValueError('The number of positions must match the number of atoms')
+ raise ValueError("The number of positions must match the number of atoms")
if is_quantity(positions):
positions = positions.value_in_unit(angstroms)
if any(math.isnan(norm(pos)) for pos in positions):
- raise ValueError('Particle position is NaN')
+ raise ValueError("Particle position is NaN")
if any(math.isinf(norm(pos)) for pos in positions):
- raise ValueError('Particle position is infinite')
+ raise ValueError("Particle position is infinite")
nonHeterogens = PDBFile._standardResidues[:]
- nonHeterogens.remove('HOH')
+ nonHeterogens.remove("HOH")
atomIndex = 1
posIndex = 0
- for (chainIndex, chain) in enumerate(topology.chains()):
+ for chainIndex, chain in enumerate(topology.chains()):
if keepIds:
chainName = chain.id
else:
- chainName = chr(ord('A')+chainIndex%26)
+ chainName = chr(ord("A") + chainIndex % 26)
residues = list(chain.residues())
- for (resIndex, res) in enumerate(residues):
+ for resIndex, res in enumerate(residues):
if keepIds:
resId = res.id
- resIC = (res.insertionCode if res.insertionCode.strip() else '.')
+ resIC = res.insertionCode if res.insertionCode.strip() else "."
else:
resId = resIndex + 1
- resIC = '.'
+ resIC = "."
if res.name in nonHeterogens:
recordName = "ATOM"
else:
@@ -434,9 +477,29 @@ def writeModel(topology, positions, file=sys.stdout, modelIndex=1, keepIds=False
if atom.element is not None:
symbol = atom.element.symbol
else:
- symbol = '?'
+ symbol = "?"
line = "%s %5d %-3s %-4s . %-4s %s ? %5s %s %10.4f %10.4f %10.4f 0.0 0.0 ? ? ? ? ? . %5s %4s %s %4s %5d"
- print(line % (recordName, atomIndex, symbol, atom.name, res.name, chainName, resId, resIC, coords[0], coords[1], coords[2],
- resId, res.name, chainName, atom.name, modelIndex), file=file)
+ print(
+ line
+ % (
+ recordName,
+ atomIndex,
+ symbol,
+ atom.name,
+ res.name,
+ chainName,
+ resId,
+ resIC,
+ coords[0],
+ coords[1],
+ coords[2],
+ resId,
+ res.name,
+ chainName,
+ atom.name,
+ modelIndex,
+ ),
+ file=file,
+ )
posIndex += 1
atomIndex += 1
diff --git a/gufe/vendor/pdb_file/singelton.py b/gufe/vendor/pdb_file/singelton.py
index 0fb0b54c..fb7b7f4b 100644
--- a/gufe/vendor/pdb_file/singelton.py
+++ b/gufe/vendor/pdb_file/singelton.py
@@ -3,13 +3,15 @@
maintains the correctness of instance is instance even following
pickling/unpickling
"""
-class Singleton(object):
+
+
+class Singleton:
_inst = None
+
def __new__(cls):
if cls._inst is None:
- cls._inst = super(Singleton, cls).__new__(cls)
+ cls._inst = super().__new__(cls)
return cls._inst
def __reduce__(self):
return repr(self)
-
diff --git a/gufe/vendor/pdb_file/topology.py b/gufe/vendor/pdb_file/topology.py
index d83fa6da..5d5e0d88 100644
--- a/gufe/vendor/pdb_file/topology.py
+++ b/gufe/vendor/pdb_file/topology.py
@@ -28,50 +28,64 @@
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
-from __future__ import absolute_import
+
__author__ = "Peter Eastman"
__version__ = "1.0"
import os
-
-import numpy as np
-from copy import deepcopy
+import xml.etree.ElementTree as etree
from collections import namedtuple
+from copy import deepcopy
-import xml.etree.ElementTree as etree
+import numpy as np
+from openmm.unit import is_quantity, nanometers, sqrt
from .singelton import Singleton
-from openmm.unit import nanometers, sqrt, is_quantity
-
# Enumerated values for bond type
+
class Single(Singleton):
def __repr__(self):
- return 'Single'
+ return "Single"
+
+
Single = Single()
+
class Double(Singleton):
def __repr__(self):
- return 'Double'
+ return "Double"
+
+
Double = Double()
+
class Triple(Singleton):
def __repr__(self):
- return 'Triple'
+ return "Triple"
+
+
Triple = Triple()
+
class Aromatic(Singleton):
def __repr__(self):
- return 'Aromatic'
+ return "Aromatic"
+
+
Aromatic = Aromatic()
+
class Amide(Singleton):
def __repr__(self):
- return 'Amide'
+ return "Amide"
+
+
Amide = Amide()
-class Topology(object):
+
+class Topology:
"""Topology stores the topological information about a system.
The structure of a Topology object is similar to that of a PDB file. It consists of a set of Chains
@@ -98,27 +112,28 @@ def __repr__(self):
nres = self._numResidues
natom = self._numAtoms
nbond = len(self._bonds)
- return '<%s; %d chains, %d residues, %d atoms, %d bonds>' % (
- type(self).__name__, nchains, nres, natom, nbond)
+ return "<%s; %d chains, %d residues, %d atoms, %d bonds>" % (
+ type(self).__name__,
+ nchains,
+ nres,
+ natom,
+ nbond,
+ )
def getNumAtoms(self):
- """Return the number of atoms in the Topology.
- """
+ """Return the number of atoms in the Topology."""
return self._numAtoms
def getNumResidues(self):
- """Return the number of residues in the Topology.
- """
+ """Return the number of residues in the Topology."""
return self._numResidues
def getNumChains(self):
- """Return the number of chains in the Topology.
- """
+ """Return the number of chains in the Topology."""
return len(self._chains)
def getNumBonds(self):
- """Return the number of bonds in the Topology.
- """
+ """Return the number of bonds in the Topology."""
return len(self._bonds)
def addChain(self, id=None):
@@ -136,12 +151,12 @@ def addChain(self, id=None):
the newly created Chain
"""
if id is None:
- id = str(len(self._chains)+1)
+ id = str(len(self._chains) + 1)
chain = Chain(len(self._chains), self, id)
self._chains.append(chain)
return chain
- def addResidue(self, name, chain, id=None, insertionCode=''):
+ def addResidue(self, name, chain, id=None, insertionCode=""):
"""Create a new Residue and add it to the Topology.
Parameters
@@ -161,10 +176,10 @@ def addResidue(self, name, chain, id=None, insertionCode=''):
Residue
the newly created Residue
"""
- if len(chain._residues) > 0 and self._numResidues != chain._residues[-1].index+1:
- raise ValueError('All residues within a chain must be contiguous')
+ if len(chain._residues) > 0 and self._numResidues != chain._residues[-1].index + 1:
+ raise ValueError("All residues within a chain must be contiguous")
if id is None:
- id = str(self._numResidues+1)
+ id = str(self._numResidues + 1)
residue = Residue(name, self._numResidues, chain, id, insertionCode)
self._numResidues += 1
chain._residues.append(residue)
@@ -190,10 +205,10 @@ def addAtom(self, name, element, residue, id=None):
Atom
the newly created Atom
"""
- if len(residue._atoms) > 0 and self._numAtoms != residue._atoms[-1].index+1:
- raise ValueError('All atoms within a residue must be contiguous')
+ if len(residue._atoms) > 0 and self._numAtoms != residue._atoms[-1].index + 1:
+ raise ValueError("All atoms within a residue must be contiguous")
if id is None:
- id = str(self._numAtoms+1)
+ id = str(self._numAtoms + 1)
atom = Atom(name, element, self._numAtoms, residue, id)
self._numAtoms += 1
residue._atoms.append(atom)
@@ -223,15 +238,13 @@ def chains(self):
def residues(self):
"""Iterate over all Residues in the Topology."""
for chain in self._chains:
- for residue in chain._residues:
- yield residue
+ yield from chain._residues
def atoms(self):
"""Iterate over all Atoms in the Topology."""
for chain in self._chains:
for residue in chain._residues:
- for atom in residue._atoms:
- yield atom
+ yield from residue._atoms
def bonds(self):
"""Iterate over all bonds (each represented as a tuple of two Atoms) in the Topology."""
@@ -240,20 +253,28 @@ def bonds(self):
def getPeriodicBoxVectors(self):
"""Get the vectors defining the periodic box.
- The return value may be None if this Topology does not represent a periodic structure."""
+ The return value may be None if this Topology does not represent a periodic structure.
+ """
return self._periodicBoxVectors
def setPeriodicBoxVectors(self, vectors):
"""Set the vectors defining the periodic box."""
if vectors is not None:
if not is_quantity(vectors[0][0]):
- vectors = vectors*nanometers
- if vectors[0][1] != 0*nanometers or vectors[0][2] != 0*nanometers:
- raise ValueError("First periodic box vector must be parallel to x.");
- if vectors[1][2] != 0*nanometers:
- raise ValueError("Second periodic box vector must be in the x-y plane.");
- if vectors[0][0] <= 0*nanometers or vectors[1][1] <= 0*nanometers or vectors[2][2] <= 0*nanometers or vectors[0][0] < 2*abs(vectors[1][0]) or vectors[0][0] < 2*abs(vectors[2][0]) or vectors[1][1] < 2*abs(vectors[2][1]):
- raise ValueError("Periodic box vectors must be in reduced form.");
+ vectors = vectors * nanometers
+ if vectors[0][1] != 0 * nanometers or vectors[0][2] != 0 * nanometers:
+ raise ValueError("First periodic box vector must be parallel to x.")
+ if vectors[1][2] != 0 * nanometers:
+ raise ValueError("Second periodic box vector must be in the x-y plane.")
+ if (
+ vectors[0][0] <= 0 * nanometers
+ or vectors[1][1] <= 0 * nanometers
+ or vectors[2][2] <= 0 * nanometers
+ or vectors[0][0] < 2 * abs(vectors[1][0])
+ or vectors[0][0] < 2 * abs(vectors[2][0])
+ or vectors[1][1] < 2 * abs(vectors[2][1])
+ ):
+ raise ValueError("Periodic box vectors must be in reduced form.")
self._periodicBoxVectors = deepcopy(vectors)
def getUnitCellDimensions(self):
@@ -266,19 +287,24 @@ def getUnitCellDimensions(self):
xsize = self._periodicBoxVectors[0][0].value_in_unit(nanometers)
ysize = self._periodicBoxVectors[1][1].value_in_unit(nanometers)
zsize = self._periodicBoxVectors[2][2].value_in_unit(nanometers)
- return np.array([xsize, ysize, zsize])*nanometers
+ return np.array([xsize, ysize, zsize]) * nanometers
def setUnitCellDimensions(self, dimensions):
"""Set the dimensions of the crystallographic unit cell.
This method is an alternative to setPeriodicBoxVectors() for the case of a rectangular box. It sets
- the box vectors to be orthogonal to each other and to have the specified lengths."""
+ the box vectors to be orthogonal to each other and to have the specified lengths.
+ """
if dimensions is None:
self._periodicBoxVectors = None
else:
if is_quantity(dimensions):
dimensions = dimensions.value_in_unit(nanometers)
- self._periodicBoxVectors = (np.array([dimensions[0], 0, 0]), np.array([0, dimensions[1], 0]), np.array([0, 0, dimensions[2]]))*nanometers
+ self._periodicBoxVectors = (
+ np.array([dimensions[0], 0, 0]),
+ np.array([0, dimensions[1], 0]),
+ np.array([0, 0, dimensions[2]]),
+ ) * nanometers
@staticmethod
def loadBondDefinitions(file):
@@ -291,12 +317,18 @@ def loadBondDefinitions(file):
will be used for any PDB file loaded after this is called.
"""
tree = etree.parse(file)
- for residue in tree.getroot().findall('Residue'):
+ for residue in tree.getroot().findall("Residue"):
bonds = []
- Topology._standardBonds[residue.attrib['name']] = bonds
- for bond in residue.findall('Bond'):
- bonds.append((bond.attrib['from'], bond.attrib['to'], bond.attrib['type'], int(
- bond.attrib['order'])))
+ Topology._standardBonds[residue.attrib["name"]] = bonds
+ for bond in residue.findall("Bond"):
+ bonds.append(
+ (
+ bond.attrib["from"],
+ bond.attrib["to"],
+ bond.attrib["type"],
+ int(bond.attrib["order"]),
+ )
+ )
def createStandardBonds(self):
"""Create bonds based on the atom and residue names for all standard residue types.
@@ -307,7 +339,7 @@ def createStandardBonds(self):
if not Topology._hasLoadedStandardBonds:
# Load the standard bond definitions.
- Topology.loadBondDefinitions(os.path.join(os.path.dirname(__file__), 'data', 'residues.xml'))
+ Topology.loadBondDefinitions(os.path.join(os.path.dirname(__file__), "data", "residues.xml"))
Topology._hasLoadedStandardBonds = True
for chain in self._chains:
# First build a map of atom names to atoms.
@@ -325,52 +357,57 @@ def createStandardBonds(self):
name = chain._residues[i].name
if name in Topology._standardBonds:
for bond in Topology._standardBonds[name]:
- if bond[0].startswith('-') and i > 0:
- fromResidue = i-1
+ if bond[0].startswith("-") and i > 0:
+ fromResidue = i - 1
fromAtom = bond[0][1:]
- elif bond[0].startswith('+') and i 0:
- toResidue = i-1
+ if bond[1].startswith("-") and i > 0:
+ toResidue = i - 1
toAtom = bond[1][1:]
- elif bond[1].startswith('+') and i ND1=CE1-NE2-HE2 - avoid "charged" resonance structure
- bond_atoms = (fromAtom, toAtom)
- if(name == "HIS" and "CE1" in bond_atoms and any([N in bond_atoms for N in ["ND1", "NE2"]])):
+ bond_atoms = (fromAtom, toAtom)
+ if name == "HIS" and "CE1" in bond_atoms and any([N in bond_atoms for N in ["ND1", "NE2"]]):
atoms = atomMaps[i]
ND1_protonated = "HD1" in atoms
NE2_protonated = "HE2" in atoms
-
- if(ND1_protonated and not NE2_protonated): # HD1-ND1-CE1=ND2
- if("ND1" in bond_atoms):
- bond_order = 1
+
+ if ND1_protonated and not NE2_protonated: # HD1-ND1-CE1=ND2
+ if "ND1" in bond_atoms:
+ bond_order = 1
else:
- bond_order = 2
- elif(not ND1_protonated and NE2_protonated): # ND1=CE1-NE2-HE2
- if("ND1" in bond_atoms):
- bond_order = 2
+ bond_order = 2
+ elif not ND1_protonated and NE2_protonated: # ND1=CE1-NE2-HE2
+ if "ND1" in bond_atoms:
+ bond_order = 2
else:
- bond_order = 1
- else: # does not matter if doubly or none protonated.
+ bond_order = 1
+ else: # does not matter if doubly or none protonated.
pass
- self.addBond(atomMaps[fromResidue][fromAtom], atomMaps[toResidue][toAtom], type=bond_type, order=bond_order)
+ self.addBond(
+ atomMaps[fromResidue][fromAtom],
+ atomMaps[toResidue][toAtom],
+ type=bond_type,
+ order=bond_order,
+ )
def createDisulfideBonds(self, positions):
"""Identify disulfide bonds based on proximity and add them to the
@@ -381,30 +418,31 @@ def createDisulfideBonds(self, positions):
positions : list
The list of atomic positions based on which to identify bonded atoms
"""
+
def isCyx(res):
names = [atom.name for atom in res._atoms]
- return 'SG' in names and 'HG' not in names
+ return "SG" in names and "HG" not in names
+
# This function is used to prevent multiple di-sulfide bonds from being
# assigned to a given atom.
def isDisulfideBonded(atom):
- for b in self._bonds:
- if (atom in b and b[0].name == 'SG' and
- b[1].name == 'SG'):
- return True
+ for b in self._bonds:
+ if atom in b and b[0].name == "SG" and b[1].name == "SG":
+ return True
- return False
+ return False
- cyx = [res for res in self.residues() if res.name == 'CYS' and isCyx(res)]
+ cyx = [res for res in self.residues() if res.name == "CYS" and isCyx(res)]
atomNames = [[atom.name for atom in res._atoms] for res in cyx]
for i in range(len(cyx)):
- sg1 = cyx[i]._atoms[atomNames[i].index('SG')]
+ sg1 = cyx[i]._atoms[atomNames[i].index("SG")]
pos1 = positions[sg1.index]
- candidate_distance, candidate_atom = 0.3*nanometers, None
+ candidate_distance, candidate_atom = 0.3 * nanometers, None
for j in range(i):
- sg2 = cyx[j]._atoms[atomNames[j].index('SG')]
+ sg2 = cyx[j]._atoms[atomNames[j].index("SG")]
pos2 = positions[sg2.index]
- delta = [x-y for (x,y) in zip(pos1, pos2)]
- distance = sqrt(delta[0]*delta[0] + delta[1]*delta[1] + delta[2]*delta[2])
+ delta = [x - y for (x, y) in zip(pos1, pos2)]
+ distance = sqrt(delta[0] * delta[0] + delta[1] * delta[1] + delta[2] * delta[2])
if distance < candidate_distance and not isDisulfideBonded(sg2):
candidate_distance = distance
candidate_atom = sg2
@@ -412,8 +450,10 @@ def isDisulfideBonded(atom):
if candidate_atom:
self.addBond(sg1, candidate_atom, type="Single", order=1)
-class Chain(object):
+
+class Chain:
"""A Chain object represents a chain within a Topology."""
+
def __init__(self, index, topology, id):
"""Construct a new Chain. You should call addChain() on the Topology instead of calling this directly."""
## The index of the Chain within its Topology
@@ -431,8 +471,7 @@ def residues(self):
def atoms(self):
"""Iterate over all Atoms in the Chain."""
for residue in self._residues:
- for atom in residue._atoms:
- yield atom
+ yield from residue._atoms
def __len__(self):
return len(self._residues)
@@ -440,8 +479,10 @@ def __len__(self):
def __repr__(self):
return "" % self.index
-class Residue(object):
+
+class Residue:
"""A Residue object represents a residue within a Topology."""
+
def __init__(self, name, index, chain, id, insertionCode):
"""Construct a new Residue. You should call addResidue() on the Topology instead of calling this directly."""
## The name of the Residue
@@ -462,23 +503,28 @@ def atoms(self):
def bonds(self):
"""Iterate over all Bonds involving any atom in this residue."""
- return ( bond for bond in self.chain.topology.bonds() if ((bond[0] in self._atoms) or (bond[1] in self._atoms)) )
+ return (bond for bond in self.chain.topology.bonds() if ((bond[0] in self._atoms) or (bond[1] in self._atoms)))
def internal_bonds(self):
"""Iterate over all internal Bonds."""
- return ( bond for bond in self.chain.topology.bonds() if ((bond[0] in self._atoms) and (bond[1] in self._atoms)) )
+ return (bond for bond in self.chain.topology.bonds() if ((bond[0] in self._atoms) and (bond[1] in self._atoms)))
def external_bonds(self):
"""Iterate over all Bonds to external atoms."""
- return ( bond for bond in self.chain.topology.bonds() if ((bond[0] in self._atoms) != (bond[1] in self._atoms)) )
+ return (bond for bond in self.chain.topology.bonds() if ((bond[0] in self._atoms) != (bond[1] in self._atoms)))
def __len__(self):
return len(self._atoms)
def __repr__(self):
- return "" % (self.index, self.name, self.chain.index)
+ return "" % (
+ self.index,
+ self.name,
+ self.chain.index,
+ )
+
-class Atom(object):
+class Atom:
"""An Atom object represents an atom within a Topology."""
def __init__(self, name, element, index, residue, id):
@@ -495,17 +541,25 @@ def __init__(self, name, element, index, residue, id):
self.id = id
def __repr__(self):
- return "" % (self.index, self.name, self.residue.chain.index, self.residue.index, self.residue.name)
+ return "" % (
+ self.index,
+ self.name,
+ self.residue.chain.index,
+ self.residue.index,
+ self.residue.name,
+ )
-class Bond(namedtuple('Bond', ['atom1', 'atom2'])):
+
+class Bond(namedtuple("Bond", ["atom1", "atom2"])):
"""A Bond object represents a bond between two Atoms within a Topology.
This class extends tuple, and may be interpreted as a 2 element tuple of Atom objects.
- It also has fields that can optionally be used to describe the bond order and type of bond."""
+ It also has fields that can optionally be used to describe the bond order and type of bond.
+ """
def __new__(cls, atom1, atom2, type=None, order=None):
"""Create a new Bond. You should call addBond() on the Topology instead of calling this directly."""
- bond = super(Bond, cls).__new__(cls, atom1, atom2)
+ bond = super().__new__(cls, atom1, atom2)
bond.type = type
bond.order = order
return bond
@@ -526,9 +580,9 @@ def __deepcopy__(self, memo):
return Bond(self[0], self[1], self.type, self.order)
def __repr__(self):
- s = "Bond(%s, %s" % (self[0], self[1])
+ s = f"Bond({self[0]}, {self[1]}"
if self.type is not None:
- s = "%s, type=%s" % (s, self.type)
+ s = f"{s}, type={self.type}"
if self.order is not None:
s = "%s, order=%d" % (s, self.order)
s += ")"
diff --git a/gufe/vendor/pdb_file/unitcell.py b/gufe/vendor/pdb_file/unitcell.py
index e4068707..bc65cec7 100644
--- a/gufe/vendor/pdb_file/unitcell.py
+++ b/gufe/vendor/pdb_file/unitcell.py
@@ -28,36 +28,43 @@
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
-from __future__ import absolute_import
+
__author__ = "Peter Eastman"
__version__ = "1.0"
-import numpy as np
-from openmm.unit import nanometers, is_quantity, norm, dot, radians
import math
+import numpy as np
+from openmm.unit import dot, is_quantity, nanometers, norm, radians
+
def computePeriodicBoxVectors(a_length, b_length, c_length, alpha, beta, gamma):
"""Convert lengths and angles to periodic box vectors.
-
+
Lengths should be given in nanometers and angles in radians (or as Quantity
instances)
"""
- if is_quantity(a_length): a_length = a_length.value_in_unit(nanometers)
- if is_quantity(b_length): b_length = b_length.value_in_unit(nanometers)
- if is_quantity(c_length): c_length = c_length.value_in_unit(nanometers)
- if is_quantity(alpha): alpha = alpha.value_in_unit(radians)
- if is_quantity(beta): beta = beta.value_in_unit(radians)
- if is_quantity(gamma): gamma = gamma.value_in_unit(radians)
+ if is_quantity(a_length):
+ a_length = a_length.value_in_unit(nanometers)
+ if is_quantity(b_length):
+ b_length = b_length.value_in_unit(nanometers)
+ if is_quantity(c_length):
+ c_length = c_length.value_in_unit(nanometers)
+ if is_quantity(alpha):
+ alpha = alpha.value_in_unit(radians)
+ if is_quantity(beta):
+ beta = beta.value_in_unit(radians)
+ if is_quantity(gamma):
+ gamma = gamma.value_in_unit(radians)
# Compute the vectors.
a = [a_length, 0, 0]
- b = [b_length*math.cos(gamma), b_length*math.sin(gamma), 0]
- cx = c_length*math.cos(beta)
- cy = c_length*(math.cos(alpha)-math.cos(beta)*math.cos(gamma))/math.sin(gamma)
- cz = math.sqrt(c_length*c_length-cx*cx-cy*cy)
+ b = [b_length * math.cos(gamma), b_length * math.sin(gamma), 0]
+ cx = c_length * math.cos(beta)
+ cy = c_length * (math.cos(alpha) - math.cos(beta) * math.cos(gamma)) / math.sin(gamma)
+ cz = math.sqrt(c_length * c_length - cx * cx - cy * cy)
c = [cx, cy, cz]
# If any elements are very close to 0, set them to exactly 0.
@@ -75,13 +82,14 @@ def computePeriodicBoxVectors(a_length, b_length, c_length, alpha, beta, gamma):
# Make sure they're in the reduced form required by OpenMM.
- c = c - b*round(c[1]/b[1])
- c = c - a*round(c[0]/a[0])
- b = b - a*round(b[0]/a[0])
- return (a, b, c)*nanometers
+ c = c - b * round(c[1] / b[1])
+ c = c - a * round(c[0] / a[0])
+ b = b - a * round(b[0] / a[0])
+ return (a, b, c) * nanometers
+
def reducePeriodicBoxVectors(periodicBoxVectors):
- """ Reduces the representation of the PBC. periodicBoxVectors is expected to
+ """Reduces the representation of the PBC. periodicBoxVectors is expected to
be an unpackable iterable of length-3 iterables
"""
if is_quantity(periodicBoxVectors):
@@ -92,12 +100,13 @@ def reducePeriodicBoxVectors(periodicBoxVectors):
b = np.array(b)
c = np.array(c)
- c = c - b*round(c[1]/b[1])
- c = c - a*round(c[0]/a[0])
- b = b - a*round(b[0]/a[0])
+ c = c - b * round(c[1] / b[1])
+ c = c - a * round(c[0] / a[0])
+ b = b - a * round(b[0] / a[0])
return (a, b, c) * nanometers
+
def computeLengthsAndAngles(periodicBoxVectors):
"""Convert periodic box vectors to lengths and angles.
@@ -110,7 +119,7 @@ def computeLengthsAndAngles(periodicBoxVectors):
a_length = norm(a)
b_length = norm(b)
c_length = norm(c)
- alpha = math.acos(dot(b, c)/(b_length*c_length))
- beta = math.acos(dot(c, a)/(c_length*a_length))
- gamma = math.acos(dot(a, b)/(a_length*b_length))
+ alpha = math.acos(dot(b, c) / (b_length * c_length))
+ beta = math.acos(dot(c, a) / (c_length * a_length))
+ gamma = math.acos(dot(a, b) / (a_length * b_length))
return (a_length, b_length, c_length, alpha, beta, gamma)
diff --git a/gufe/visualization/mapping_visualization.py b/gufe/visualization/mapping_visualization.py
index 0b6c02be..ac110b8a 100644
--- a/gufe/visualization/mapping_visualization.py
+++ b/gufe/visualization/mapping_visualization.py
@@ -1,12 +1,11 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/gufe
-from typing import Any, Collection, Optional
+from collections.abc import Collection
from itertools import chain
+from typing import Any, Optional
from rdkit import Chem
-from rdkit.Chem import Draw
-from rdkit.Chem import AllChem
-
+from rdkit.Chem import AllChem, Draw
# highlight core element changes differently from unique atoms
# RGBA color value needs to be between 0 and 1, so divide by 255
@@ -14,9 +13,7 @@
BLUE = (0.0, 90 / 255, 181 / 255, 1.0)
-def _match_elements(
- mol1: Chem.Mol, idx1: int, mol2: Chem.Mol, idx2: int
-) -> bool:
+def _match_elements(mol1: Chem.Mol, idx1: int, mol2: Chem.Mol, idx2: int) -> bool:
"""
Convenience method to check if elements between two molecules (molA
and molB) are the same.
@@ -42,9 +39,7 @@ def _match_elements(
return elem_mol1 == elem_mol2
-def _get_unique_bonds_and_atoms(
- mapping: dict[int, int], mol1: Chem.Mol, mol2: Chem.Mol
-) -> dict:
+def _get_unique_bonds_and_atoms(mapping: dict[int, int], mol1: Chem.Mol, mol2: Chem.Mol) -> dict:
"""
Given an input mapping, returns new atoms, element changes, and
involved bonds.
@@ -143,13 +138,14 @@ def _draw_molecules(
if d2d is None:
# select default layout based on number of molecules
- grid_x, grid_y = {1: (1, 1), 2: (2, 1), }[len(mols)]
- d2d = Draw.rdMolDraw2D.MolDraw2DCairo(
- grid_x * 300, grid_y * 300, 300, 300)
+ grid_x, grid_y = {
+ 1: (1, 1),
+ 2: (2, 1),
+ }[len(mols)]
+ d2d = Draw.rdMolDraw2D.MolDraw2DCairo(grid_x * 300, grid_y * 300, 300, 300)
# get molecule name labels
- labels = [m.GetProp("ofe-name") if(m.HasProp("ofe-name"))
- else "" for m in mols]
+ labels = [m.GetProp("ofe-name") if (m.HasProp("ofe-name")) else "" for m in mols]
# squash to 2D
copies = [Chem.Mol(mol) for mol in mols]
@@ -158,10 +154,7 @@ def _draw_molecules(
# mol alignments if atom_mapping present
for (i, j), atomMap in atom_mapping.items():
- AllChem.AlignMol(
- copies[j], copies[i],
- atomMap=[(k, v) for v, k in atomMap.items()]
- )
+ AllChem.AlignMol(copies[j], copies[i], atomMap=[(k, v) for v, k in atomMap.items()])
# standard settings for our visualization
d2d.drawOptions().useBWAtomPalette()
@@ -180,9 +173,7 @@ def _draw_molecules(
return d2d.GetDrawingText()
-def draw_mapping(
- mol1_to_mol2: dict[int, int], mol1: Chem.Mol, mol2: Chem.Mol, d2d=None
-):
+def draw_mapping(mol1_to_mol2: dict[int, int], mol1: Chem.Mol, mol2: Chem.Mol, d2d=None):
"""
Method to visualise the atom map correspondence between two rdkit
molecules given an input mapping.
@@ -229,7 +220,6 @@ def draw_mapping(
atom_colors = [at1_colors, at2_colors]
bond_colors = [bd1_colors, bd2_colors]
-
return _draw_molecules(
d2d,
[mol1, mol2],
@@ -270,14 +260,15 @@ def draw_one_molecule_mapping(mol1_to_mol2, mol1, mol2, d2d=None):
atom_colors = [{at: BLUE for at in uniques["elements"]}]
bond_colors = [{bd: BLUE for bd in uniques["bond_changes"]}]
- return _draw_molecules(d2d,
- [mol1],
- atoms_list,
- bonds_list,
- atom_colors,
- bond_colors,
- RED,
- )
+ return _draw_molecules(
+ d2d,
+ [mol1],
+ atoms_list,
+ bonds_list,
+ atom_colors,
+ bond_colors,
+ RED,
+ )
def draw_unhighlighted_molecule(mol, d2d=None):
diff --git a/pyproject.toml b/pyproject.toml
index d68b3488..f7833013 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -77,6 +77,26 @@ dirty = "{base_version}+{distance}.{vcs}{rev}.dirty"
distance-dirty = "{base_version}+{distance}.{vcs}{rev}.dirty"
[tool.versioningit.vcs]
-method = "git"
+method = "git"
match = ["*"]
default-tag = "0.0.0"
+
+[tool.black]
+line-length = 120
+
+[tool.isort]
+multi_line_output = 3
+include_trailing_comma = true
+force_grid_wrap = 0
+use_parentheses = true
+line_length = 120
+profile = "black"
+known_first_party = ["gufe"]
+
+[tool.interrogate]
+fail-under = 0
+ignore-regex = ["^get$", "^mock_.*", ".*BaseClass.*"]
+# possible values for verbose: 0 (minimal output), 1 (-v), 2 (-vv)
+verbose = 2
+color = true
+exclude = ["build", "docs"]