Skip to content

Commit

Permalink
fix static methods; add scan bond length
Browse files Browse the repository at this point in the history
  • Loading branch information
rzhu committed Nov 29, 2024
1 parent 23949d3 commit 644c150
Showing 1 changed file with 55 additions and 4 deletions.
59 changes: 55 additions & 4 deletions src/molearn/analysis/analyser.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,10 @@ def get_dope(self, key, refine=True, **kwargs):
dataset = self.get_dataset(key)
decoded = self.get_decoded(key)

dope_dataset = self.get_all_dope_score(dataset, refine=refine, **kwargs)
dope_decoded = self.get_all_dope_score(decoded, refine=refine, **kwargs)
dataset_dope = self.get_all_dope_score(dataset, refine=refine, **kwargs)
decoded_dope = self.get_all_dope_score(decoded, refine=refine, **kwargs)

return dict(dataset_dope=dope_dataset, decoded_dope=dope_decoded)
return dict(dataset_dope=dataset_dope, decoded_dope=decoded_dope)

def get_ramachandran(self, key):
"""
Expand All @@ -259,6 +259,51 @@ def get_ramachandran(self, key):
)
return ramachandran

def get_bondlengths(self, key):
"""
Get backbone bond lengths of a dataset and its decoded counterpart.
"""
# Get the atomic indices to calculate different types of bond lengths
if set(['CA','C', 'N','CB']).issubset(set(self.atoms)):
indices = {'N-Ca':[], 'Ca-C':[], 'C-N':[], 'CA-CB':[]}
elif set(['CA','C', 'N']).issubset(set(self.atoms)):
indices = {'N-Ca':[], 'Ca-C':[], 'C-N':[]}
else:
raise ValueError(f"Selected atoms should contain at least N, CA, and C.")

mol_df = self.mol.data
for resid in mol_df.resid.unique():
resname = mol_df[mol_df['resid'] == resid].resname.unique()[0]

N_id = mol_df[(mol_df['resid'] == resid) & (mol_df['name'] == 'N')].index[0]
CA_id = mol_df[(mol_df['resid'] == resid) & (mol_df['name'] == 'CA')].index[0]
C_id = mol_df[(mol_df['resid'] == resid) & (mol_df['name'] == 'C')].index[0]
indices['N-Ca'].append((N_id, CA_id))
indices['Ca-C'].append((CA_id, C_id))
if resname != 'GLY' and 'CB' in self.atoms:
CB_id = mol_df[(mol_df['resid'] == resid) & (mol_df['name'] == 'CB')].index[0]
indices['Ca-Cb'].append((CA_id, CB_id))

if resid != len(mol_df.resid.unique()):
next_N_id = mol_df[(mol_df['resid'] == (resid+1)) & (mol_df['name'] == 'N')].index[0]
indices['C-N'].append((C_id, next_N_id))

# Look for the key in self._datasets and self._encoded
if key in self._datasets.keys():
dataset = self.get_dataset(key)*self.stdval + self.meanval
decoded = self.get_decoded(key)*self.stdval + self.meanval
dataset_bondlen = {k: MolearnAnalysis._bond_lengths(dataset, v) for k, v in indices.items()}
decoded_bondlen = {k: MolearnAnalysis._bond_lengths(decoded, v) for k, v in indices.items()}
return dict(dataset_bondlen=dataset_bondlen, decoded_bondlen=decoded_bondlen)
elif key in self._encoded.keys():
decoded = self.get_decoded(key)*self.stdval + self.meanval
decoded_bondlen = {k: MolearnAnalysis._bond_lengths(decoded, v) for k, v in indices.items()}
return dict(decoded_bondlen=decoded_bondlen)
else:
raise ValueError(f"Key {key} not found in _datasets or _encoded. Please load the dataset or setup a grid first.")


def setup_grid(self, samples=64, bounds_from=None, bounds=None, padding=0.1):
"""
Define a NxN point grid regularly sampling the latent space.
Expand Down Expand Up @@ -455,6 +500,7 @@ def _dope_score(self, frame, refine=True, **kwargs):

return self.dope_score_class.get_score(f * self.stdval, refine=refine, **kwargs)

@staticmethod
def _ca_chirality(N, CA, C, CB):
"""
Calculate chirality of Cα atom in a protein residue.
Expand All @@ -473,6 +519,11 @@ def _ca_chirality(N, CA, C, CB):
# L if dot_product > 0 else D
return dot_product

@staticmethod
def _bond_lengths(crds, indices):
bond_lengths = [np.linalg.norm(crds[:,:,i[0]] - crds[:,:,i[1]], axis=1) for i in indices]
return np.array(bond_lengths)

def get_all_ramachandran_score(self, tensor):
"""
Calculate Ramachandran score of an ensemble of atomic conrdinates.
Expand Down Expand Up @@ -606,7 +657,7 @@ def scan_ca_chirality(self):

# Get atom indices
indices = dict()
for resid in range(len(mol_df.resid.unique())):
for resid in mol_df.resid.unique():
resname = mol_df[mol_df['resid'] == resid].resname.unique()[0]
if not resname == 'GLY':
N_id = mol_df[(mol_df['resid'] == resid) & (mol_df['name'] == 'N')].index[0]
Expand Down

0 comments on commit 644c150

Please sign in to comment.