Skip to content

Commit

Permalink
Merge pull request #622 from ISISNeutronMuon/chi/md-traj-converter
Browse files Browse the repository at this point in the history
Add MDTraj converter and update to atom mapping code
  • Loading branch information
MBartkowiakSTFC authored Jan 3, 2025
2 parents 405b94b + 2a64111 commit f134bd3
Show file tree
Hide file tree
Showing 29 changed files with 918 additions and 207 deletions.
87 changes: 81 additions & 6 deletions MDANSE/Src/MDANSE/Framework/AtomMapping/atom_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,67 @@

class AtomLabel:

def __init__(self, atm_label, **kwargs):
self.atm_label = atm_label
def __init__(self, atm_label: str, **kwargs):
"""Creates an atom label object which is used for atom mapping
and atom type guessing.
Parameters
----------
atm_label : str
The main atom label.
kwargs
The other atom label.
"""
# use translations since it's faster than the alternative
# methods as of writing e.g. re.sub
translation = str.maketrans("", "", ";=")
self.atm_label = atm_label.translate(translation)
self.grp_label = f""
if kwargs:
for k, v in kwargs.items():
self.grp_label += f"{k}={v};"
self.grp_label += f"{k}={str(v).translate(translation)};"
self.grp_label = self.grp_label[:-1]
self.mass = kwargs.get("mass", None)
if self.mass is not None:
self.mass = float(self.mass)

def __eq__(self, other: object) -> bool:
"""Used to check if atom labels are equal.
Parameters
----------
other : AtomLabel
The other atom label to compare against.
Returns
-------
bool
True if all attributes are equal.
Raises
------
AssertionError
If the other object is not an AtomLabel.
"""
if not isinstance(other, AtomLabel):
AssertionError(f"{other} should be an instance of AtomLabel.")
if self.grp_label == other.grp_label and self.atm_label == other.atm_label:
if (
self.grp_label == other.grp_label
and self.atm_label == other.atm_label
and self.mass == other.mass
):
return True
else:
return False

def __hash__(self) -> int:
"""
Returns
-------
int
A hash of the object in its current state.
"""
return hash((self.atm_label, self.grp_label, self.mass))


def guess_element(atm_label: str, mass: Union[float, int, None] = None) -> str:
Expand Down Expand Up @@ -179,6 +224,30 @@ def fill_remaining_labels(
mapping[grp_label][atm_label] = guess_element(atm_label, label.mass)


def mapping_to_labels(mapping: dict[str, dict[str, str]]) -> list[AtomLabel]:
"""Converts the mapping back into a list of labels.
Parameters
----------
mapping : dict[str, dict[str, str]]
The atom mapping dictionary.
Returns
-------
list[AtomLabel]
List of atom labels from the mapping.
"""
labels = []
for grp_label, atm_map in mapping.items():
kwargs = {}
if grp_label:
for k, v in [i.split("=") for i in grp_label.split(";")]:
kwargs[k] = v
for atm_label in atm_map.keys():
labels.append(AtomLabel(atm_label, **kwargs))
return labels


def check_mapping_valid(mapping: dict[str, dict[str, str]], labels: list[AtomLabel]):
"""Given a list of labels check that the mapping is valid.
Expand All @@ -194,11 +263,17 @@ def check_mapping_valid(mapping: dict[str, dict[str, str]], labels: list[AtomLab
bool
True if the mapping is valid.
"""
pattern = re.compile("^([A-Za-z]\w*=[^=;]+(;[A-Za-z]\w*=[^=;]+)*)*$")
if not all([pattern.match(grp_label) for grp_label in mapping.keys()]):
return False

if set(mapping_to_labels(mapping)) != set(labels):
return False

for label in labels:
grp_label = label.grp_label
atm_label = label.atm_label
if grp_label not in mapping or atm_label not in mapping[grp_label]:
return False
if mapping[grp_label][atm_label] not in ATOMS_DATABASE:
return False

return True
18 changes: 8 additions & 10 deletions MDANSE/Src/MDANSE/Framework/Configurators/ASEFileConfigurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
from typing import Iterable

from ase.io import iread, read
from ase.io.trajectory import Trajectory as ASETrajectory

Expand All @@ -37,16 +39,12 @@ def parse(self):

self["element_list"] = first_frame.get_chemical_symbols()

def get_atom_labels(self) -> list[AtomLabel]:
def atom_labels(self) -> Iterable[AtomLabel]:
"""
Returns
-------
list[AtomLabel]
An ordered list of atom labels.
Yields
------
AtomLabel
An atom label.
"""
labels = []
for atm_label in self["element_list"]:
label = AtomLabel(atm_label)
if label not in labels:
labels.append(label)
return labels
yield AtomLabel(atm_label)
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ def configure(self, value) -> None:

file_configurator = self._configurable[self._dependencies["input_file"]]
if not file_configurator._valid:
self.error_status = "Input file not selected."
self.error_status = "Input file not selected or valid."
return

labels = file_configurator.get_atom_labels()
labels = file_configurator.labels
try:
fill_remaining_labels(value, labels)
except AttributeError:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
from typing import Iterable
import re

import numpy as np

from MDANSE.Core.Error import Error
from MDANSE.Framework.AtomMapping import AtomLabel

from .FileWithAtomDataConfigurator import FileWithAtomDataConfigurator
from MDANSE.MLogging import LOG

Expand Down Expand Up @@ -186,16 +187,12 @@ def parse(self):
except:
LOG.error(f"LAMMPS ConfigFileConfigurator failed to find a unit cell")

def get_atom_labels(self) -> list[AtomLabel]:
def atom_labels(self) -> Iterable[AtomLabel]:
"""
Returns
-------
list[AtomLabel]
An ordered list of atom labels.
Yields
------
AtomLabel
An atom label.
"""
labels = []
for idx, mass in self["elements"]:
label = AtomLabel(str(idx), mass=mass)
if label not in labels:
labels.append(label)
return labels
yield AtomLabel(str(idx), mass=mass)
18 changes: 7 additions & 11 deletions MDANSE/Src/MDANSE/Framework/Configurators/FieldFileConfigurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
from typing import Iterable
import re

import numpy as np
Expand All @@ -23,7 +24,6 @@
)
from MDANSE.Core.Error import Error
from MDANSE.Framework.AtomMapping import get_element_from_mapping, AtomLabel

from .FileWithAtomDataConfigurator import FileWithAtomDataConfigurator


Expand Down Expand Up @@ -114,20 +114,16 @@ def parse(self):

first = last + 1

def get_atom_labels(self) -> list[AtomLabel]:
def atom_labels(self) -> Iterable[AtomLabel]:
"""
Returns
-------
list[AtomLabel]
An ordered list of atom labels.
Yields
------
AtomLabel
An atom label.
"""
labels = []
for mol_name, _, atomic_contents, masses, _ in self["molecules"]:
for atm_label, mass in zip(atomic_contents, masses):
label = AtomLabel(atm_label, molecule=mol_name, mass=mass)
if label not in labels:
labels.append(label)
return labels
yield AtomLabel(atm_label, molecule=mol_name, mass=mass)

def get_atom_charges(self) -> np.ndarray:
"""Returns an array of partial electric charges
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
from typing import Iterable
from abc import abstractmethod
import traceback

Expand All @@ -37,13 +38,27 @@ def configure(self, filepath: str) -> None:
self.parse()
except Exception as e:
self.error_status = f"File parsing error {e}: {traceback.format_exc()}"
return

self.labels = self.unique_labels()
if len(self.labels) == 0:
self.error_status = f"Unable to generate atom labels"
return

@abstractmethod
def parse(self) -> None:
"""Parse the file."""
pass

@abstractmethod
def get_atom_labels(self) -> list[AtomLabel]:
"""Return the atoms labels in the file."""
pass
def atom_labels(self) -> Iterable[AtomLabel]:
"""Yields atom labels"""

def unique_labels(self) -> list[AtomLabel]:
"""
Returns
-------
list[AtomLabel]
An ordered list of atom labels.
"""
return list(set(self.atom_labels()))
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,15 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
import ast
import os
from typing import Union

import MDAnalysis as mda

from MDANSE import PLATFORM
from MDANSE.Framework.Configurators.IConfigurator import IConfigurator
from .MultiInputFileConfigurator import MultiInputFileConfigurator


class CoordinateFileConfigurator(IConfigurator):
class MDAnalysisCoordinateFileConfigurator(MultiInputFileConfigurator):

_default = ("", "AUTO")

Expand All @@ -41,46 +39,7 @@ def configure(self, setting: tuple[Union[str, list], str]):
string of the coordinate file format.
"""
values, format = setting

self["values"] = self._default
self._original_input = values

if type(values) is str:
if values:
try:
values = ast.literal_eval(values)
except (SyntaxError, ValueError) as e:
self.error_status = f"Unable to evaluate string: {e}"
return
if type(values) is not list:
self.error_status = (
f"Input values should be able to be evaluated as a list"
)
return
else:
values = []

if type(values) is list:
if not all([type(value) is str for value in values]):
self.error_status = f"Input values should be a list of str"
return
else:
self.error_status = f"Input values should be able to be evaluated as a list"
return

values = [PLATFORM.get_path(value) for value in values]

none_exist = []
for value in values:
if not os.path.isfile(value):
none_exist.append(value)

if none_exist:
self.error_status = f"The files {', '.join(none_exist)} do not exist."
return

self["values"] = values
self["filenames"] = values
super().configure(values)

if format == "AUTO" or not self["filenames"]:
self["format"] = None
Expand All @@ -95,15 +54,15 @@ def configure(self, setting: tuple[Union[str, list], str]):
if topology_configurator._valid:
try:
if len(self["filenames"]) <= 1 or self["format"] is None:
traj = mda.Universe(
_ = mda.Universe(
topology_configurator["filename"],
*self["filenames"],
format=self["format"],
topology_format=topology_configurator["format"],
).trajectory
else:
coord_files = [(i, self["format"]) for i in self["filenames"]]
traj = mda.Universe(
_ = mda.Universe(
topology_configurator["filename"],
coord_files,
topology_format=topology_configurator["format"],
Expand All @@ -116,8 +75,6 @@ def configure(self, setting: tuple[Union[str, list], str]):
self.error_status = "Requires valid topology file."
return

self.error_status = "OK"

@property
def wildcard(self):
return self._wildcard
Expand Down
Loading

0 comments on commit f134bd3

Please sign in to comment.