Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MDTraj converter and update to atom mapping code #622

Merged
merged 33 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
670461b
MDTraj converter initial commit.
ChiCheng45 Dec 11, 2024
ba9d268
update the MDTraj atom labels, update the MDTraj widgets, refactored …
ChiCheng45 Dec 12, 2024
3ff02f7
update atom mapping so that extra mapping are not valid.
ChiCheng45 Dec 16, 2024
3df9f9d
fixed atom mapping check
ChiCheng45 Dec 16, 2024
e6b04f1
fixed atom mapping check
ChiCheng45 Dec 16, 2024
97651b3
fixed topology file config and improved atom mapping performance
ChiCheng45 Dec 16, 2024
544a5b5
refactored mapping check
ChiCheng45 Dec 16, 2024
290aaa8
speed up atom mapping functions
ChiCheng45 Dec 16, 2024
a2581db
refactored FileWithAtomDataConfigurator and improved performance for …
ChiCheng45 Dec 17, 2024
062c60c
added docstring to atom label and correct __eq__ method
ChiCheng45 Dec 17, 2024
dec2b34
corrected comment
ChiCheng45 Dec 17, 2024
99ea07f
improve atom mapping dialog performance
ChiCheng45 Dec 17, 2024
2f37c30
applied black
ChiCheng45 Dec 17, 2024
ac10b41
updated mdtraj converter so that multiple coordinate files can be used.
ChiCheng45 Dec 17, 2024
6f51462
fixed topology file configurator when no atom labels can be generated
ChiCheng45 Dec 17, 2024
dadd5ba
fixed mdtraj topology configurator
ChiCheng45 Dec 17, 2024
3e2c5df
refactored atom mapping code
ChiCheng45 Dec 18, 2024
f60c08a
refactored atom mapping code
ChiCheng45 Dec 18, 2024
085f6a8
refactored atom mapping code
ChiCheng45 Dec 18, 2024
9258ca0
mdanalysis atom mapping small performance improvement
ChiCheng45 Dec 18, 2024
6016d68
add mdtraj dependency
ChiCheng45 Dec 18, 2024
7630ddf
fixed unit test
ChiCheng45 Dec 18, 2024
0ca5930
fixed atom mapping regex and mapping to labels
ChiCheng45 Dec 18, 2024
044a782
add mdtraj unit test
ChiCheng45 Dec 18, 2024
bbf9d12
fixed test name
ChiCheng45 Dec 18, 2024
8ad1698
reduced number of calls to unique labels
ChiCheng45 Dec 19, 2024
7c4549c
refactored MDAnalysis time step widget
ChiCheng45 Jan 2, 2025
2b44a05
refactor mdanalysis time step widget and update mdtraj converter to i…
ChiCheng45 Jan 2, 2025
4f807a9
Merge branch 'protos' into chi/md-traj-converter
ChiCheng45 Jan 2, 2025
8fe98ee
fix encoding bug
ChiCheng45 Jan 2, 2025
93aeac3
applied black
ChiCheng45 Jan 2, 2025
3e22958
fixed MDTraj timestep
ChiCheng45 Jan 3, 2025
2a64111
MdTraj atom mapping fix
ChiCheng45 Jan 3, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 81 additions & 6 deletions MDANSE/Src/MDANSE/Framework/AtomMapping/atom_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,67 @@

class AtomLabel:
    """Label identifying a kind of atom, used for atom mapping and atom
    type guessing.

    Two labels are equal (and hash the same) when their ``atm_label``,
    ``grp_label`` and ``mass`` all match.
    """

    # Translation table deleting ";" and "=": these characters are the
    # separators of the group label, so they must not appear inside a
    # label value. A translation table is used since it is faster than
    # the alternative methods as of writing e.g. re.sub
    _STRIP_SEPARATORS = str.maketrans("", "", ";=")

    def __init__(self, atm_label: str, **kwargs):
        """Creates an atom label object which is used for atom mapping
        and atom type guessing.

        Parameters
        ----------
        atm_label : str
            The main atom label. Any ";" or "=" characters are removed.
        kwargs
            The other atom labels. Each one becomes a ``key=value``
            entry of ``grp_label`` (entries are joined with ";", in the
            order given, with ";" and "=" removed from the values). A
            ``mass`` entry is additionally stored as a float in
            ``mass``.
        """
        strip = self._STRIP_SEPARATORS
        self.atm_label = atm_label.translate(strip)
        self.grp_label = ";".join(
            f"{k}={str(v).translate(strip)}" for k, v in kwargs.items()
        )
        mass = kwargs.get("mass")
        # mass stays None when it was not supplied
        self.mass = None if mass is None else float(mass)

    def __eq__(self, other: object) -> bool:
        """Used to check if atom labels are equal.

        Parameters
        ----------
        other : object
            The other atom label to compare against.

        Returns
        -------
        bool
            True if all attributes are equal. ``NotImplemented`` is
            returned when ``other`` is not an AtomLabel, so that Python
            falls back to its default comparison (the objects then
            compare as not equal) instead of failing on a missing
            attribute.
        """
        if not isinstance(other, AtomLabel):
            return NotImplemented
        return (
            self.grp_label == other.grp_label
            and self.atm_label == other.atm_label
            and self.mass == other.mass
        )

    def __hash__(self) -> int:
        """
        Returns
        -------
        int
            A hash of the object in its current state.
        """
        # hashes the same fields that __eq__ compares
        return hash((self.atm_label, self.grp_label, self.mass))


def guess_element(atm_label: str, mass: Union[float, int, None] = None) -> str:
Expand Down Expand Up @@ -179,6 +224,30 @@ def fill_remaining_labels(
mapping[grp_label][atm_label] = guess_element(atm_label, label.mass)


def mapping_to_labels(mapping: dict[str, dict[str, str]]) -> list[AtomLabel]:
    """Rebuilds the atom labels described by an atom mapping.

    Parameters
    ----------
    mapping : dict[str, dict[str, str]]
        The atom mapping dictionary, keyed by group label and then by
        atom label.

    Returns
    -------
    list[AtomLabel]
        One atom label per (group label, atom label) pair, in the
        iteration order of the mapping.
    """
    result = []
    for group, atoms in mapping.items():
        # A group label is a ";" separated run of "key=value" pairs; an
        # empty group label carries no extra keywords.
        extras = dict(pair.split("=") for pair in group.split(";")) if group else {}
        result.extend(AtomLabel(name, **extras) for name in atoms)
    return result


def check_mapping_valid(
    mapping: dict[str, dict[str, str]], labels: list[AtomLabel]
) -> bool:
    """Given a list of labels check that the mapping is valid.

    The mapping is valid when every group label is well formed, the
    mapping describes exactly the given labels (nothing missing and
    nothing extra) and every mapped value is found in ATOMS_DATABASE.

    Parameters
    ----------
    mapping : dict[str, dict[str, str]]
        The atom mapping dictionary, keyed by group label and then by
        atom label.
    labels : list[AtomLabel]
        The atom labels which the mapping has to cover.

    Returns
    -------
    bool
        True if the mapping is valid.
    """
    # Group labels are either empty or ";" separated "key=value" pairs.
    # Raw string: "\w" inside a plain literal is an invalid escape.
    pattern = re.compile(r"^([A-Za-z]\w*=[^=;]+(;[A-Za-z]\w*=[^=;]+)*)*$")
    if not all(pattern.match(grp_label) for grp_label in mapping):
        return False

    try:
        mapped_labels = set(mapping_to_labels(mapping))
    except ValueError:
        # e.g. a "mass=..." entry which cannot be converted to a float;
        # such a mapping is invalid rather than an error.
        return False
    if mapped_labels != set(labels):
        return False

    for label in labels:
        grp_label = label.grp_label
        atm_label = label.atm_label
        # AtomLabel strips ";" and "=" from atom labels, so a mapping
        # key can differ from the label it produced; guard the lookup.
        if grp_label not in mapping or atm_label not in mapping[grp_label]:
            return False
        if mapping[grp_label][atm_label] not in ATOMS_DATABASE:
            return False

    return True
18 changes: 8 additions & 10 deletions MDANSE/Src/MDANSE/Framework/Configurators/ASEFileConfigurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
from typing import Iterable

from ase.io import iread, read
from ase.io.trajectory import Trajectory as ASETrajectory

Expand All @@ -37,16 +39,12 @@ def parse(self):

self["element_list"] = first_frame.get_chemical_symbols()

def atom_labels(self) -> Iterable[AtomLabel]:
    """Yields one atom label for each entry of ``self["element_list"]``.

    Entries are not deduplicated here, so the same label can be yielded
    more than once.

    Yields
    ------
    AtomLabel
        An atom label.
    """
    for atm_label in self["element_list"]:
        yield AtomLabel(atm_label)
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ def configure(self, value) -> None:

file_configurator = self._configurable[self._dependencies["input_file"]]
if not file_configurator._valid:
self.error_status = "Input file not selected."
self.error_status = "Input file not selected or valid."
return

labels = file_configurator.get_atom_labels()
labels = file_configurator.labels
try:
fill_remaining_labels(value, labels)
except AttributeError:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
from typing import Iterable
import re

import numpy as np

from MDANSE.Core.Error import Error
from MDANSE.Framework.AtomMapping import AtomLabel

from .FileWithAtomDataConfigurator import FileWithAtomDataConfigurator
from MDANSE.MLogging import LOG

Expand Down Expand Up @@ -186,16 +187,12 @@ def parse(self):
except:
LOG.error(f"LAMMPS ConfigFileConfigurator failed to find a unit cell")

def atom_labels(self) -> Iterable[AtomLabel]:
    """Yields one atom label for each ``(idx, mass)`` pair of
    ``self["elements"]``.

    The index is converted to a string and used as the main label; the
    mass is passed on as the ``mass`` keyword. Entries are not
    deduplicated here.

    Yields
    ------
    AtomLabel
        An atom label.
    """
    for idx, mass in self["elements"]:
        yield AtomLabel(str(idx), mass=mass)
18 changes: 7 additions & 11 deletions MDANSE/Src/MDANSE/Framework/Configurators/FieldFileConfigurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
from typing import Iterable
import re

import numpy as np
Expand All @@ -23,7 +24,6 @@
)
from MDANSE.Core.Error import Error
from MDANSE.Framework.AtomMapping import get_element_from_mapping, AtomLabel

from .FileWithAtomDataConfigurator import FileWithAtomDataConfigurator


Expand Down Expand Up @@ -114,20 +114,16 @@ def parse(self):

first = last + 1

def atom_labels(self) -> Iterable[AtomLabel]:
    """Yields one atom label for each atom of each entry of
    ``self["molecules"]``.

    Each label carries the molecule name and the atom mass as the
    ``molecule`` and ``mass`` keywords. Entries are not deduplicated
    here.

    Yields
    ------
    AtomLabel
        An atom label.
    """
    for mol_name, _, atomic_contents, masses, _ in self["molecules"]:
        # atomic_contents and masses are walked in lockstep
        for atm_label, mass in zip(atomic_contents, masses):
            yield AtomLabel(atm_label, molecule=mol_name, mass=mass)

def get_atom_charges(self) -> np.ndarray:
"""Returns an array of partial electric charges
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
from typing import Iterable
from abc import abstractmethod
import traceback

Expand All @@ -37,13 +38,27 @@ def configure(self, filepath: str) -> None:
self.parse()
except Exception as e:
self.error_status = f"File parsing error {e}: {traceback.format_exc()}"
return

self.labels = self.unique_labels()
if len(self.labels) == 0:
self.error_status = f"Unable to generate atom labels"
return

@abstractmethod
def parse(self) -> None:
    """Parse the file.

    Abstract hook with no behaviour of its own; the concrete file
    configurators supply the parsing.
    """

@abstractmethod
def atom_labels(self) -> Iterable[AtomLabel]:
    """Yields atom labels.

    Abstract hook with no behaviour of its own; the concrete file
    configurators supply the labels.

    Yields
    ------
    AtomLabel
        An atom label.
    """

def unique_labels(self) -> list[AtomLabel]:
    """Collects the labels from ``atom_labels`` without duplicates.

    Returns
    -------
    list[AtomLabel]
        An ordered list of atom labels: each distinct label appears
        once, in the order in which it is first yielded.
    """
    # dict keys keep insertion order, unlike a set whose iteration order
    # is arbitrary and may change between runs.
    return list(dict.fromkeys(self.atom_labels()))
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,15 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
import ast
import os
from typing import Union

import MDAnalysis as mda

from MDANSE import PLATFORM
from MDANSE.Framework.Configurators.IConfigurator import IConfigurator
from .MultiInputFileConfigurator import MultiInputFileConfigurator


class CoordinateFileConfigurator(IConfigurator):
class MDAnalysisCoordinateFileConfigurator(MultiInputFileConfigurator):

_default = ("", "AUTO")

Expand All @@ -41,46 +39,7 @@ def configure(self, setting: tuple[Union[str, list], str]):
string of the coordinate file format.
"""
values, format = setting

self["values"] = self._default
self._original_input = values

if type(values) is str:
if values:
try:
values = ast.literal_eval(values)
except (SyntaxError, ValueError) as e:
self.error_status = f"Unable to evaluate string: {e}"
return
if type(values) is not list:
self.error_status = (
f"Input values should be able to be evaluated as a list"
)
return
else:
values = []

if type(values) is list:
if not all([type(value) is str for value in values]):
self.error_status = f"Input values should be a list of str"
return
else:
self.error_status = f"Input values should be able to be evaluated as a list"
return

values = [PLATFORM.get_path(value) for value in values]

none_exist = []
for value in values:
if not os.path.isfile(value):
none_exist.append(value)

if none_exist:
self.error_status = f"The files {', '.join(none_exist)} do not exist."
return

self["values"] = values
self["filenames"] = values
super().configure(values)

if format == "AUTO" or not self["filenames"]:
self["format"] = None
Expand All @@ -95,15 +54,15 @@ def configure(self, setting: tuple[Union[str, list], str]):
if topology_configurator._valid:
try:
if len(self["filenames"]) <= 1 or self["format"] is None:
traj = mda.Universe(
_ = mda.Universe(
topology_configurator["filename"],
*self["filenames"],
format=self["format"],
topology_format=topology_configurator["format"],
).trajectory
else:
coord_files = [(i, self["format"]) for i in self["filenames"]]
traj = mda.Universe(
_ = mda.Universe(
topology_configurator["filename"],
coord_files,
topology_format=topology_configurator["format"],
Expand All @@ -116,8 +75,6 @@ def configure(self, setting: tuple[Union[str, list], str]):
self.error_status = "Requires valid topology file."
return

self.error_status = "OK"

@property
def wildcard(self):
    """Read-only access to the value stored in ``self._wildcard``."""
    current = self._wildcard
    return current
Expand Down
Loading
Loading