
Commit

Merge pull request #26 from Degiacomi-Lab/dev
New features in analyser
degiacom authored Dec 11, 2024
2 parents 7968cfc + 6bcdf39 commit d4700a4
Showing 3 changed files with 239 additions and 37 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -275,3 +275,5 @@ ERROR*
LOG*
*.swp
*.swo

archive/
272 changes: 236 additions & 36 deletions src/molearn/analysis/analyser.py
@@ -16,6 +16,8 @@
from copy import deepcopy
import numpy as np
import torch.optim
from pathlib import Path
from typing import Union

try:
# from modeller import *
@@ -87,11 +89,11 @@ def get_dataset(self, key):
"""
return self._datasets[key]

def set_dataset(self, key, data, atomselect="*"):
def set_dataset(self, key, data, atomselect="protein"):
"""
:param data: :func:`PDBData <molearn.data.PDBData>` object containing atomic coordinates
:param str key: label to be associated with data
:param list/str atomselect: list of atom names to load, or '*' to indicate that all atoms are loaded.
:param list/str atomselect: list of atom names to load, or 'protein' to load all protein atoms.
"""
if isinstance(data, str) and data.endswith(".pdb"):
d = PDBData()
@@ -232,10 +234,10 @@ def get_dope(self, key, refine=True, **kwargs):
dataset = self.get_dataset(key)
decoded = self.get_decoded(key)

dope_dataset = self.get_all_dope_score(dataset, refine=refine, **kwargs)
dope_decoded = self.get_all_dope_score(decoded, refine=refine, **kwargs)
dataset_dope = self.get_all_dope_score(dataset, refine=refine, **kwargs)
decoded_dope = self.get_all_dope_score(decoded, refine=refine, **kwargs)

return dict(dataset_dope=dope_dataset, decoded_dope=dope_decoded)
return dict(dataset_dope=dataset_dope, decoded_dope=decoded_dope)

def get_ramachandran(self, key):
"""
@@ -257,6 +259,66 @@ def get_ramachandran(self, key):
)
return ramachandran

def get_bondlengths(self, key):
"""
Get backbone bond lengths of a dataset and its decoded counterpart.
:param str key: key of a dataset loaded via set_dataset, or of encoded coordinates (e.g. 'grid' after setup_grid)
:return: dictionary of bond length arrays, under the keys 'dataset_bondlen' and/or 'decoded_bondlen'
"""
# Get the atomic indices to calculate different types of bond lengths
if set(["CA", "C", "N", "CB"]).issubset(set(self.atoms)):
indices = {"N-Ca": [], "Ca-C": [], "C-N": [], "CA-CB": []}
elif set(["CA", "C", "N"]).issubset(set(self.atoms)):
indices = {"N-Ca": [], "Ca-C": [], "C-N": []}
else:
raise ValueError("Selected atoms should contain at least N, CA, and C.")

mol_df = self.mol.data
for resid in mol_df.resid.unique():
resname = mol_df[mol_df["resid"] == resid].resname.unique()[0]

N_id = mol_df[(mol_df["resid"] == resid) & (mol_df["name"] == "N")].index[0]
CA_id = mol_df[(mol_df["resid"] == resid) & (mol_df["name"] == "CA")].index[
0
]
C_id = mol_df[(mol_df["resid"] == resid) & (mol_df["name"] == "C")].index[0]
indices["N-Ca"].append((N_id, CA_id))
indices["Ca-C"].append((CA_id, C_id))
if resname != "GLY" and "CB" in self.atoms:
CB_id = mol_df[
(mol_df["resid"] == resid) & (mol_df["name"] == "CB")
].index[0]
indices["Ca-Cb"].append((CA_id, CB_id))

if resid != len(mol_df.resid.unique()):
next_N_id = mol_df[
(mol_df["resid"] == (resid + 1)) & (mol_df["name"] == "N")
].index[0]
indices["C-N"].append((C_id, next_N_id))

# Look for the key in self._datasets and self._encoded
if key in self._datasets.keys():
dataset = self.get_dataset(key) * self.stdval + self.meanval
decoded = self.get_decoded(key) * self.stdval + self.meanval
dataset_bondlen = {
k: MolearnAnalysis._bond_lengths(dataset, v) for k, v in indices.items()
}
decoded_bondlen = {
k: MolearnAnalysis._bond_lengths(decoded, v) for k, v in indices.items()
}
return dict(
dataset_bondlen=dataset_bondlen, decoded_bondlen=decoded_bondlen
)
elif key in self._encoded.keys():
decoded = self.get_decoded(key) * self.stdval + self.meanval
decoded_bondlen = {
k: MolearnAnalysis._bond_lengths(decoded, v) for k, v in indices.items()
}
return dict(decoded_bondlen=decoded_bondlen)
else:
raise ValueError(
f"Key {key} not found in _datasets or _encoded. Please load the dataset or setup a grid first."
)
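
A rough usage sketch of the new method (the file name, dataset key, and trained_network object below are placeholders; set_network/set_dataset follow the usual MolearnAnalysis workflow):

from molearn.analysis import MolearnAnalysis

MA = MolearnAnalysis()
MA.set_network(trained_network)             # a trained molearn network, assumed available
MA.set_dataset("training", "training.pdb")  # selection must include at least N, CA and C

bonds = MA.get_bondlengths("training")
# bonds["dataset_bondlen"]["N-Ca"] holds one row per residue and one column per frame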

def setup_grid(self, samples=64, bounds_from=None, bounds=None, padding=0.1):
"""
Define a NxN point grid regularly sampling the latent space.
@@ -453,6 +515,51 @@ def _dope_score(self, frame, refine=True, **kwargs):

return self.dope_score_class.get_score(f * self.stdval, refine=refine, **kwargs)

def _chirality_whole(
self,
n: np.ndarray[tuple[int, int], np.dtype[np.float64]],
ca: np.ndarray[tuple[int, int], np.dtype[np.float64]],
c: np.ndarray[tuple[int, int], np.dtype[np.float64]],
cb: np.ndarray[tuple[int, int], np.dtype[np.float64]],
):
"""
Check chirality for a set of amino acids (vectorised version of _ca_chirality).
"""
ca_n = n - ca
ca_c = c - ca
cb_ca = cb - ca
normal = np.cross(ca_n, ca_c)
dot = np.einsum("ij,ij->i", normal, cb_ca)
# equivalent, but computes the full matrix product first:
# dot = np.diagonal(np.matmul(normal, cb_ca.T))
return dot

@staticmethod
def _ca_chirality(N, CA, C, CB):
"""
Calculate chirality of Cα atom in a protein residue.
:param N: Cartesian coordinates of N atom
:param CA: Cartesian coordinates of CA atom
:param C: Cartesian coordinates of C atom
:param CB: Cartesian coordinates of CB atom
:return: dot product of the normal to the N-CA-C plane with the CA-CB vector (positive for L, negative for D)
"""
ca_c = C - CA
ca_cb = CB - CA
ca_n = N - CA
normal = np.cross(ca_n, ca_c)
dot_product = np.dot(normal, ca_cb)
# L if dot_product > 0 else D
return dot_product
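
To make the sign test concrete, a toy example with made-up coordinates:

import numpy as np

N  = np.array([1.0, 0.0, 0.0])
CA = np.array([0.0, 0.0, 0.0])
C  = np.array([0.0, 1.0, 0.0])
CB = np.array([0.0, 0.0, 1.0])

normal = np.cross(N - CA, C - CA)  # [0, 0, 1], normal to the N-CA-C plane
dot = np.dot(normal, CB - CA)      # 1.0 > 0, counted as the L form by _ca_chirality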

@staticmethod
def _bond_lengths(crds, indices):
bond_lengths = [
np.linalg.norm(crds[:, :, i[0]] - crds[:, :, i[1]], axis=1) for i in indices
]
return np.array(bond_lengths)
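
_bond_lengths assumes coordinates with the xyz components on axis 1 and atoms on axis 2, as returned by get_dataset/get_decoded; a quick shape check with dummy data:

import numpy as np

crds = np.random.rand(5, 3, 10)   # 5 frames, 10 atoms
pairs = [(0, 1), (1, 2)]          # hypothetical atom index pairs
lengths = np.array(
    [np.linalg.norm(crds[:, :, i] - crds[:, :, j], axis=1) for i, j in pairs]
)
print(lengths.shape)              # (2, 5): one row per bond, one column per frame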

def get_all_ramachandran_score(self, tensor):
"""
Calculate Ramachandran score of an ensemble of atomic coordinates.
@@ -560,6 +667,68 @@ def scan_ramachandran(self):

return self.surfaces["Ramachandran_favored"], self.xvals, self.yvals

def scan_ca_chirality(self):
"""
Calculate chiralities of Cα atoms on a grid sampling the latent space.
Requires a grid system to be defined via a prior call to :func:`setup_grid <molearn.analysis.MolearnAnalysis.setup_grid>`.
Requires that the atom selection includes N, CA, C, and CB atoms.
Saves a surface in memory, with key 'Chirality'.
:return: Number of CA inversions on the latent space NxN surface
:return: x-axis values
:return: y-axis values
"""
assert set(["CA", "C", "N", "CB"]).issubset(
set(self.atoms)
), "Atom selection shoud at least include CA, C, N, and CB"

key = "Chirality"
if key not in self.surfaces:
assert (
"grid" in self._encoded
), "make sure to call MolearnAnalysis.setup_grid first"
decoded = self.get_decoded("grid")

mol_df = self.mol.data

# Get atom indices
indices = dict()
for resid in mol_df.resid.unique():
resname = mol_df[mol_df["resid"] == resid].resname.unique()[0]
if not resname == "GLY":
N_id = mol_df[
(mol_df["resid"] == resid) & (mol_df["name"] == "N")
].index[0]
C_id = mol_df[
(mol_df["resid"] == resid) & (mol_df["name"] == "C")
].index[0]
CA_id = mol_df[
(mol_df["resid"] == resid) & (mol_df["name"] == "CA")
].index[0]
CB_id = mol_df[
(mol_df["resid"] == resid) & (mol_df["name"] == "CB")
].index[0]
indices[resname + str(resid)] = (N_id, CA_id, C_id, CB_id)

idx = np.asarray(list(indices.values()))
results = []
for j in decoded:
s = (j.view(1, 3, -1).permute(0, 2, 1) * self.stdval).numpy().squeeze()
chir_test = self._chirality_whole(
s[idx[:, 0], :],
s[idx[:, 1], :],
s[idx[:, 2], :],
s[idx[:, 3], :],
)
wrong_chir = chir_test < 0
results.append(wrong_chir.sum())
results = np.asarray(results)
self.surfaces[key] = np.array(results).reshape(
self.n_samples, self.n_samples
)

return self.surfaces[key], self.xvals, self.yvals
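
A rough usage sketch, continuing from the MolearnAnalysis setup above (grid size and bounds key are placeholders):

MA.setup_grid(samples=50, bounds_from="training")
chirality, xvals, yvals = MA.scan_ca_chirality()
# chirality[i, j] is the number of inverted (D-form) CA atoms in the structure
# decoded at the corresponding grid point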

def scan_custom(self, fct, params, key):
"""
Generate a surface coloured as a function of a user-defined function.
@@ -580,43 +749,74 @@ def scan_custom(self, fct, params, key):

return self.surfaces[key], self.xvals, self.yvals

def _relax(self, pdb_path: str, mini_path: str) -> None:
"""
relax generated structure
:param str pdb_path: path of the pdb file to relax
:param str mini_path: path where the new relaxed structure should be saved
"""
# read pdb
pdb = PDBFile(pdb_path)
# preparation
forcefield = ForceField("amber99sb.xml")
modeller = Modeller(pdb.topology, pdb.positions)
modeller.addHydrogens(forcefield)
system = forcefield.createSystem(modeller.topology, nonbondedMethod=NoCutoff)
integrator = VerletIntegrator(0.001 * picoseconds)
simulation = Simulation(modeller.topology, system, integrator)
simulation.context.setPositions(modeller.positions)
# minimize protein
simulation.minimizeEnergy(maxIterations=100)
positions = simulation.context.getState(getPositions=True).getPositions()
# write new pdb file
PDBFile.writeFile(simulation.topology, positions, open(mini_path, "w+"))
def _relax(
self,
pdb_file: Union[str, Path],
out_path: Union[str, Path],
maxIterations: int = 1000,
) -> None:
"""
Model the sidechains and relax a generated structure.
:param str pdb_file: path to the pdb file generated by the model
:param str out_path: path where the modelled/relaxed structures should be saved
:param int maxIterations: maximum number of energy-minimisation iterations
"""

if not isinstance(pdb_file, str):
pdb_file = str(pdb_file)
if not isinstance(out_path, str):
out_path = str(out_path)

# Assume sidechain modelling is required if the number of selected atoms is fewer than 6
if len(self.atoms) < 6:
modelled_file = out_path + os.sep + (Path(pdb_file).stem + "_modelled.pdb")
try:
env = Environ()
env.libs.topology.read(file="$(LIB)/top_heav.lib")
env.libs.parameters.read(file="$(LIB)/par.lib")

mdl = complete_pdb(env, str(pdb_file))
mdl.write(str(modelled_file))
pdb_file = modelled_file
except Exception as e:
print(f"Failed to model {pdb_file}\n{e}")
try:
relaxed_file = out_path + os.sep + (Path(pdb_file).stem + "_relaxed.pdb")
# Read pdb
pdb = PDBFile(pdb_file)
# Add hydrogens
forcefield = ForceField("amber99sb.xml")
modeller = Modeller(pdb.topology, pdb.positions)
modeller.addHydrogens(forcefield)

system = forcefield.createSystem(
modeller.topology, nonbondedMethod=NoCutoff
)
integrator = VerletIntegrator(0.001 * picoseconds)
simulation = Simulation(modeller.topology, system, integrator)
simulation.context.setPositions(modeller.positions)
# Energy minimization
simulation.minimizeEnergy(maxIterations=maxIterations)
positions = simulation.context.getState(getPositions=True).getPositions()
# Write energy minimized file
PDBFile.writeFile(simulation.topology, positions, open(relaxed_file, "w+"))
except Exception as e:
print(f"Failed to relax {pdb_file}\n{e}")

def _pdb_file(
self,
prot_coords: np.ndarray[tuple[int, int], np.dtype[np.float64]],
outpath: str,
pdb_file: str,
) -> None:
"""
create pdb file for given coordinates
:param np.ndarray[tuple[int, int], np.dtype[np.float64]] prot_coords: coordinates of all atoms of a protein
:param str outpath: path where the pdb file should be stored
:param str pdb_file: path where the pdb file should be stored
"""
pdb_data = self.mol.data
with open(
outpath,
pdb_file,
"w+",
) as cfile:
for ck, k in enumerate(prot_coords):
@@ -661,14 +861,14 @@ def generate(
gen_prot_coords = s * self.stdval + self.meanval
# create pdb file
if pdb_path is not None:
for ci, i in enumerate(tqdm(gen_prot_coords, desc="Generating pdb file")):
struct_path = os.path.join(pdb_path, f"s{ci}.pdb")
self._pdb_file(i, struct_path)
for i, coord in enumerate(
tqdm(gen_prot_coords, desc="Generating pdb files")
):
struct_path = os.path.join(pdb_path, f"s{i}.pdb")
self._pdb_file(coord, struct_path)
# relax and save as new file
if relax:
self._relax(
struct_path, f"{os.path.splitext(struct_path)[0]}_relax.pdb"
)
self._relax(struct_path, pdb_path, maxIterations=1000)

return gen_prot_coords
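
A rough usage sketch of the updated generate, reusing the MA instance from the earlier sketch and assuming it accepts an array of latent coordinates as in earlier molearn releases (latent points and output folder are placeholders):

import os
import numpy as np

os.makedirs("generated", exist_ok=True)
latent = np.array([[0.0, 0.0], [0.5, -0.2]])   # points in the 2D latent space
coords = MA.generate(latent, pdb_path="generated", relax=True)
# writes generated/s0.pdb, generated/s1.pdb (plus relaxed copies) and
# returns the decoded Cartesian coordinates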

2 changes: 1 addition & 1 deletion src/molearn/data/pdb_data.py
@@ -71,7 +71,7 @@ def import_pdb(self, filename: str | list[str], topology: str | None = None):
first_universe = mda.Universe(filename[0])
self._mol = mda.Universe(first_universe._topology, filename)
if isinstance(filename, list) and topology is not None:
first_universe = mda.Universe(topology[0], filename[0])
first_universe = mda.Universe(topology, filename[0])
self._mol = mda.Universe(first_universe._topology, filename)
elif topology is None:
self._mol = mda.Universe(filename)
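
The pdb_data.py change matters when several coordinate files share a single topology; a sketch of that use case (file names are placeholders):

from molearn.data import PDBData

data = PDBData()
data.import_pdb(["traj_part1.pdb", "traj_part2.pdb"], topology="system.pdb")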
